mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
graph color curriculum (#303)
This commit is contained in:
parent
2fca962847
commit
e0f8ef061d
3 changed files with 121 additions and 11 deletions
|
|
@ -15,7 +15,7 @@ from .count_primes import CountPrimesConfig, CountPrimesCurriculum, CountPrimesD
|
||||||
from .cryptarithm import CryptarithmConfig, CryptarithmDataset
|
from .cryptarithm import CryptarithmConfig, CryptarithmDataset
|
||||||
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
|
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
|
||||||
from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset
|
from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset
|
||||||
from .graph_color import GraphColorConfig, GraphColorDataset
|
from .graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDataset
|
||||||
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset
|
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset
|
||||||
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset
|
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset
|
||||||
from .jugs import JugsConfig, JugsDataset
|
from .jugs import JugsConfig, JugsDataset
|
||||||
|
|
@ -113,6 +113,7 @@ __all__ = [
|
||||||
"CountPrimesCurriculum",
|
"CountPrimesCurriculum",
|
||||||
"GraphColorConfig",
|
"GraphColorConfig",
|
||||||
"GraphColorDataset",
|
"GraphColorDataset",
|
||||||
|
"GraphColorCurriculum",
|
||||||
"StringInsertionConfig",
|
"StringInsertionConfig",
|
||||||
"StringInsertionDataset",
|
"StringInsertionDataset",
|
||||||
"StringManipulationConfig",
|
"StringManipulationConfig",
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -154,8 +155,10 @@ def greedy_graph_coloring(puzzle):
|
||||||
class GraphColorConfig:
|
class GraphColorConfig:
|
||||||
"""Configuration for GraphColor puzzle generation"""
|
"""Configuration for GraphColor puzzle generation"""
|
||||||
|
|
||||||
num_colors: int = 4
|
min_num_colors: int = 3
|
||||||
num_vertices: int = 10
|
max_num_colors: int = 3
|
||||||
|
min_num_vertices: int = 10
|
||||||
|
max_num_vertices: int = 10
|
||||||
edge_probability: float = 0.4
|
edge_probability: float = 0.4
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
@ -187,9 +190,9 @@ class GraphColorDataset(ProceduralDataset):
|
||||||
while solution is None:
|
while solution is None:
|
||||||
puzzle = generate_graph_coloring_puzzle(
|
puzzle = generate_graph_coloring_puzzle(
|
||||||
rng=rng,
|
rng=rng,
|
||||||
num_vertices=self.config.num_vertices,
|
num_vertices=rng.randint(self.config.min_num_vertices, self.config.max_num_vertices),
|
||||||
edge_probability=self.config.edge_probability,
|
edge_probability=self.config.edge_probability,
|
||||||
num_colors=self.config.num_colors,
|
num_colors=rng.randint(self.config.min_num_colors, self.config.max_num_colors),
|
||||||
)
|
)
|
||||||
solution = greedy_graph_coloring(puzzle)
|
solution = greedy_graph_coloring(puzzle)
|
||||||
|
|
||||||
|
|
@ -237,4 +240,32 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)
|
class GraphColorCurriculum(BaseCurriculum):
    """Curriculum for the graph coloring task.

    Registers two range attributes, ``num_vertices`` and ``num_colors``,
    each bound to the matching ``min_*`` / ``max_*`` pair of fields on
    ``GraphColorConfig``.
    """

    def __init__(self):
        super().__init__(GraphColorCurriculum.__name__, GraphColorConfig)

        # Graph size: levels step through 10, 20, 30 and 40 vertices.
        vertices_attr = RangeAttributeDefinition(
            name="num_vertices",
            levels=[10, 20, 30, 40],
            default_level=0,
            description="Number of vertices in the graph",
            attr_type=AttributeType.APPEND,
            min_value=10,
            lower_field_name="min_num_vertices",
            upper_field_name="max_num_vertices",
        )

        # Palette size: levels step through 3, 4, 5 and 6 colors.
        colors_attr = RangeAttributeDefinition(
            name="num_colors",
            levels=[3, 4, 5, 6],
            default_level=0,
            description="Number of colors in the graph",
            attr_type=AttributeType.APPEND,
            min_value=2,
            lower_field_name="min_num_colors",
            upper_field_name="max_num_colors",
        )

        self._define_attributes(vertices_attr, colors_attr)
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("graph_color", GraphColorDataset, GraphColorConfig, GraphColorCurriculum)
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,20 @@ import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.graph_color import GraphColorConfig, GraphColorDataset
|
from reasoning_gym.algorithmic.graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDataset
|
||||||
|
|
||||||
|
|
||||||
def test_graph_color():
|
def test_graph_color():
|
||||||
"""Test basic properties and solution of generated items"""
|
"""Test basic properties and solution of generated items"""
|
||||||
config = GraphColorConfig(seed=42, size=10, num_vertices=10, num_colors=4, edge_probability=0.4)
|
config = GraphColorConfig(
|
||||||
|
seed=42,
|
||||||
|
size=10,
|
||||||
|
min_num_vertices=10,
|
||||||
|
max_num_vertices=10,
|
||||||
|
min_num_colors=4,
|
||||||
|
max_num_colors=4,
|
||||||
|
edge_probability=0.4,
|
||||||
|
)
|
||||||
dataset = GraphColorDataset(config)
|
dataset = GraphColorDataset(config)
|
||||||
|
|
||||||
# easy
|
# easy
|
||||||
|
|
@ -22,7 +30,15 @@ def test_graph_color():
|
||||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
# medium
|
# medium
|
||||||
config = GraphColorConfig(seed=42, size=1, num_vertices=10, num_colors=3, edge_probability=0.3)
|
config = GraphColorConfig(
|
||||||
|
seed=42,
|
||||||
|
size=1,
|
||||||
|
min_num_vertices=10,
|
||||||
|
max_num_vertices=10,
|
||||||
|
min_num_colors=3,
|
||||||
|
max_num_colors=3,
|
||||||
|
edge_probability=0.3,
|
||||||
|
)
|
||||||
dataset = GraphColorDataset(config)
|
dataset = GraphColorDataset(config)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
|
|
@ -30,7 +46,15 @@ def test_graph_color():
|
||||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
# hard
|
# hard
|
||||||
config = GraphColorConfig(seed=42, size=1, num_vertices=40, num_colors=4, edge_probability=0.2)
|
config = GraphColorConfig(
|
||||||
|
seed=42,
|
||||||
|
size=1,
|
||||||
|
min_num_vertices=40,
|
||||||
|
max_num_vertices=40,
|
||||||
|
min_num_colors=4,
|
||||||
|
max_num_colors=4,
|
||||||
|
edge_probability=0.2,
|
||||||
|
)
|
||||||
dataset = GraphColorDataset(config)
|
dataset = GraphColorDataset(config)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
|
|
@ -38,9 +62,63 @@ def test_graph_color():
|
||||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
# v hard
|
# v hard
|
||||||
config = GraphColorConfig(seed=42, size=1, num_vertices=50, num_colors=3, edge_probability=0.1)
|
config = GraphColorConfig(
|
||||||
|
seed=42,
|
||||||
|
size=1,
|
||||||
|
min_num_vertices=50,
|
||||||
|
max_num_vertices=50,
|
||||||
|
min_num_colors=3,
|
||||||
|
max_num_colors=3,
|
||||||
|
edge_probability=0.1,
|
||||||
|
)
|
||||||
dataset = GraphColorDataset(config)
|
dataset = GraphColorDataset(config)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
assert dataset.score_answer(answer=json.dumps(item["metadata"]["possible_answer"]), entry=item) == 1.0
|
assert dataset.score_answer(answer=json.dumps(item["metadata"]["possible_answer"]), entry=item) == 1.0
|
||||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_color_curriculum():
    """Check the config ranges produced at each curriculum level."""
    curriculum = GraphColorCurriculum()
    base_value = {"size": 150, "seed": 1}
    both_attrs = ("num_vertices", "num_colors")

    # Default level: both ranges collapse to their first level value.
    base_cfg: GraphColorConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.size == 150
    assert base_cfg.seed == 1
    assert (base_cfg.min_num_vertices, base_cfg.max_num_vertices) == (10, 10)
    assert (base_cfg.min_num_colors, base_cfg.max_num_colors) == (3, 3)

    # Raising only the vertex level widens the vertex range.
    curriculum.increment_attr_level("num_vertices")
    cfg = curriculum.generate_configuration(base_value)
    assert (cfg.min_num_vertices, cfg.max_num_vertices) == (10, 20)

    # Raising only the color level widens the color range.
    curriculum.increment_attr_level("num_colors")
    cfg = curriculum.generate_configuration(base_value)
    assert (cfg.min_num_colors, cfg.max_num_colors) == (3, 4)

    # Raise both attributes together, twice, up to the top level.
    for attr in both_attrs:
        curriculum.increment_attr_level(attr)
    cfg = curriculum.generate_configuration(base_value)
    assert (cfg.min_num_vertices, cfg.max_num_vertices) == (10, 30)
    assert (cfg.min_num_colors, cfg.max_num_colors) == (3, 5)

    for attr in both_attrs:
        curriculum.increment_attr_level(attr)
    cfg = curriculum.generate_configuration(base_value)
    assert (cfg.min_num_vertices, cfg.max_num_vertices) == (10, 40)
    assert (cfg.min_num_colors, cfg.max_num_colors) == (3, 6)

    # Stepping both back down restores the previous ranges.
    for attr in both_attrs:
        curriculum.decrement_attr_level(attr)
    cfg = curriculum.generate_configuration(base_value)
    assert (cfg.min_num_vertices, cfg.max_num_vertices) == (10, 30)
    assert (cfg.min_num_colors, cfg.max_num_colors) == (3, 5)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue