graph color curriculum (#303)

This commit is contained in:
vncntt 2025-03-08 23:20:47 -08:00 committed by GitHub
parent 2fca962847
commit e0f8ef061d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 121 additions and 11 deletions

View file

@ -15,7 +15,7 @@ from .count_primes import CountPrimesConfig, CountPrimesCurriculum, CountPrimesD
from .cryptarithm import CryptarithmConfig, CryptarithmDataset from .cryptarithm import CryptarithmConfig, CryptarithmDataset
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset
from .graph_color import GraphColorConfig, GraphColorDataset from .graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDataset
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset
from .jugs import JugsConfig, JugsDataset from .jugs import JugsConfig, JugsDataset
@ -113,6 +113,7 @@ __all__ = [
"CountPrimesCurriculum", "CountPrimesCurriculum",
"GraphColorConfig", "GraphColorConfig",
"GraphColorDataset", "GraphColorDataset",
"GraphColorCurriculum",
"StringInsertionConfig", "StringInsertionConfig",
"StringInsertionDataset", "StringInsertionDataset",
"StringManipulationConfig", "StringManipulationConfig",

View file

@ -3,6 +3,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -154,8 +155,10 @@ def greedy_graph_coloring(puzzle):
class GraphColorConfig: class GraphColorConfig:
"""Configuration for GraphColor puzzle generation""" """Configuration for GraphColor puzzle generation"""
num_colors: int = 4 min_num_colors: int = 3
num_vertices: int = 10 max_num_colors: int = 3
min_num_vertices: int = 10
max_num_vertices: int = 10
edge_probability: float = 0.4 edge_probability: float = 0.4
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
@ -187,9 +190,9 @@ class GraphColorDataset(ProceduralDataset):
while solution is None: while solution is None:
puzzle = generate_graph_coloring_puzzle( puzzle = generate_graph_coloring_puzzle(
rng=rng, rng=rng,
num_vertices=self.config.num_vertices, num_vertices=rng.randint(self.config.min_num_vertices, self.config.max_num_vertices),
edge_probability=self.config.edge_probability, edge_probability=self.config.edge_probability,
num_colors=self.config.num_colors, num_colors=rng.randint(self.config.min_num_colors, self.config.max_num_colors),
) )
solution = greedy_graph_coloring(puzzle) solution = greedy_graph_coloring(puzzle)
@ -237,4 +240,32 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
return 0.0 return 0.0
register_dataset("graph_color", GraphColorDataset, GraphColorConfig) class GraphColorCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(GraphColorCurriculum.__name__, GraphColorConfig)
self._define_attributes(
RangeAttributeDefinition(
name="num_vertices",
levels=[10, 20, 30, 40],
default_level=0,
description="Number of vertices in the graph",
attr_type=AttributeType.APPEND,
min_value=10,
lower_field_name="min_num_vertices",
upper_field_name="max_num_vertices",
),
RangeAttributeDefinition(
name="num_colors",
levels=[3, 4, 5, 6],
default_level=0,
description="Number of colors in the graph",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_num_colors",
upper_field_name="max_num_colors",
),
)
register_dataset("graph_color", GraphColorDataset, GraphColorConfig, GraphColorCurriculum)

View file

@ -2,12 +2,20 @@ import json
import pytest import pytest
from reasoning_gym.algorithmic.graph_color import GraphColorConfig, GraphColorDataset from reasoning_gym.algorithmic.graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDataset
def test_graph_color(): def test_graph_color():
"""Test basic properties and solution of generated items""" """Test basic properties and solution of generated items"""
config = GraphColorConfig(seed=42, size=10, num_vertices=10, num_colors=4, edge_probability=0.4) config = GraphColorConfig(
seed=42,
size=10,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=4,
max_num_colors=4,
edge_probability=0.4,
)
dataset = GraphColorDataset(config) dataset = GraphColorDataset(config)
# easy # easy
@ -22,7 +30,15 @@ def test_graph_color():
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
# medium # medium
config = GraphColorConfig(seed=42, size=1, num_vertices=10, num_colors=3, edge_probability=0.3) config = GraphColorConfig(
seed=42,
size=1,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=3,
max_num_colors=3,
edge_probability=0.3,
)
dataset = GraphColorDataset(config) dataset = GraphColorDataset(config)
for item in dataset: for item in dataset:
@ -30,7 +46,15 @@ def test_graph_color():
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
# hard # hard
config = GraphColorConfig(seed=42, size=1, num_vertices=40, num_colors=4, edge_probability=0.2) config = GraphColorConfig(
seed=42,
size=1,
min_num_vertices=40,
max_num_vertices=40,
min_num_colors=4,
max_num_colors=4,
edge_probability=0.2,
)
dataset = GraphColorDataset(config) dataset = GraphColorDataset(config)
for item in dataset: for item in dataset:
@ -38,9 +62,63 @@ def test_graph_color():
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
# v hard # v hard
config = GraphColorConfig(seed=42, size=1, num_vertices=50, num_colors=3, edge_probability=0.1) config = GraphColorConfig(
seed=42,
size=1,
min_num_vertices=50,
max_num_vertices=50,
min_num_colors=3,
max_num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config) dataset = GraphColorDataset(config)
for item in dataset: for item in dataset:
assert dataset.score_answer(answer=json.dumps(item["metadata"]["possible_answer"]), entry=item) == 1.0 assert dataset.score_answer(answer=json.dumps(item["metadata"]["possible_answer"]), entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
def test_graph_color_curriculum():
curriculum = GraphColorCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: GraphColorConfig = curriculum.generate_configuration(base_value)
assert base_cfg.size == 150
assert base_cfg.seed == 1
assert base_cfg.min_num_vertices == base_cfg.max_num_vertices == 10
assert base_cfg.min_num_colors == base_cfg.max_num_colors == 3
curriculum.increment_attr_level("num_vertices")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 20
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 4
curriculum.increment_attr_level("num_vertices")
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5
curriculum.increment_attr_level("num_vertices")
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 40
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 6
curriculum.decrement_attr_level("num_vertices")
curriculum.decrement_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5