diff --git a/reasoning_gym/algorithmic/graph_color.py b/reasoning_gym/algorithmic/graph_color.py index 408a9f43..4dc57589 100644 --- a/reasoning_gym/algorithmic/graph_color.py +++ b/reasoning_gym/algorithmic/graph_color.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional -from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -155,11 +155,10 @@ def greedy_graph_coloring(puzzle): class GraphColorConfig: """Configuration for GraphColor puzzle generation""" - min_num_colors: int = 3 - max_num_colors: int = 3 + num_colors: int = 3 min_num_vertices: int = 10 max_num_vertices: int = 10 - edge_probability: float = 0.4 + edge_probability: float = 0.1 seed: Optional[int] = None size: int = 500 @@ -188,7 +187,7 @@ class GraphColorDataset(ProceduralDataset): puzzle = None solution = None num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices) - num_colors = rng.randint(self.config.min_num_colors, self.config.max_num_colors) + num_colors = self.config.num_colors while solution is None: puzzle = generate_graph_coloring_puzzle( rng=rng, @@ -197,6 +196,8 @@ class GraphColorDataset(ProceduralDataset): num_colors=num_colors, ) solution = greedy_graph_coloring(puzzle) + if not solution: + num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices) edges = str(puzzle["edges"]) question = f"""Please provide a coloring for this graph such that every vertex is not connected to a vertex of the same color. The graph has these properties: @@ -253,23 +254,22 @@ class GraphColorCurriculum(BaseCurriculum): self._define_attributes( RangeAttributeDefinition( name="num_vertices", - levels=[10, 20, 30, 40], + levels=[10, 20, 25, 50], default_level=0, description="Number of vertices in the graph", - attr_type=AttributeType.APPEND, + attr_type=AttributeType.STATIC, min_value=10, lower_field_name="min_num_vertices", upper_field_name="max_num_vertices", ), - RangeAttributeDefinition( + ScalarAttributeDefinition( name="num_colors", - levels=[3, 4, 5, 6], + field_name="num_colors", + levels=[5, 4, 3], default_level=0, description="Number of colors in the graph", - attr_type=AttributeType.APPEND, - min_value=2, - lower_field_name="min_num_colors", - upper_field_name="max_num_colors", + attr_type=AttributeType.STATIC, + min_value=3, ), ) diff --git a/tests/test_graph_color.py b/tests/test_graph_color.py index 7f631e6f..7b81842c 100644 --- a/tests/test_graph_color.py +++ b/tests/test_graph_color.py @@ -12,8 +12,7 @@ def test_graph_color(): size=10, min_num_vertices=10, max_num_vertices=10, - min_num_colors=4, - max_num_colors=4, + num_colors=4, edge_probability=0.4, ) dataset = GraphColorDataset(config) @@ -35,9 +34,8 @@ def test_graph_color(): size=1, min_num_vertices=10, max_num_vertices=10, - min_num_colors=3, - max_num_colors=3, - edge_probability=0.3, + num_colors=3, + edge_probability=0.1, ) dataset = GraphColorDataset(config) @@ -49,11 +47,10 @@ def test_graph_color(): config = GraphColorConfig( seed=42, size=1, - min_num_vertices=40, - max_num_vertices=40, - min_num_colors=4, - max_num_colors=4, - edge_probability=0.2, + min_num_vertices=15, + max_num_vertices=15, + num_colors=3, + edge_probability=0.1, ) dataset = GraphColorDataset(config) @@ -67,8 +64,7 @@ def test_graph_color(): size=1, min_num_vertices=50, max_num_vertices=50, - min_num_colors=3, - max_num_colors=3, + num_colors=3, edge_probability=0.1, ) dataset = GraphColorDataset(config) @@ -87,38 +83,15 @@ def test_graph_color_curriculum(): assert base_cfg.size == 150 assert base_cfg.seed == 1 assert base_cfg.min_num_vertices == base_cfg.max_num_vertices == 10 - assert base_cfg.min_num_colors == base_cfg.max_num_colors == 3 + assert base_cfg.num_colors == base_cfg.num_colors == 5 curriculum.increment_attr_level("num_vertices") cfg = curriculum.generate_configuration(base_value) - assert cfg.min_num_vertices == 10 - assert cfg.max_num_vertices == 20 + assert cfg.min_num_vertices == 20 curriculum.increment_attr_level("num_colors") cfg = curriculum.generate_configuration(base_value) - assert cfg.min_num_colors == 3 - assert cfg.max_num_colors == 4 - - curriculum.increment_attr_level("num_vertices") + assert cfg.num_colors == 4 curriculum.increment_attr_level("num_colors") cfg = curriculum.generate_configuration(base_value) - assert cfg.min_num_vertices == 10 - assert cfg.max_num_vertices == 30 - assert cfg.min_num_colors == 3 - assert cfg.max_num_colors == 5 - - curriculum.increment_attr_level("num_vertices") - curriculum.increment_attr_level("num_colors") - cfg = curriculum.generate_configuration(base_value) - assert cfg.min_num_vertices == 10 - assert cfg.max_num_vertices == 40 - assert cfg.min_num_colors == 3 - assert cfg.max_num_colors == 6 - - curriculum.decrement_attr_level("num_vertices") - curriculum.decrement_attr_level("num_colors") - cfg = curriculum.generate_configuration(base_value) - assert cfg.min_num_vertices == 10 - assert cfg.max_num_vertices == 30 - assert cfg.min_num_colors == 3 - assert cfg.max_num_colors == 5 + assert cfg.num_colors == 3