Correct Graph Coloring Difficulty (#318)

* correct gcolor difficulty

* refactor test
This commit is contained in:
Rich Jones 2025-03-11 00:14:38 +01:00 committed by GitHub
parent d9ef4f4d14
commit 2b8f21c502
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 52 deletions

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -155,11 +155,10 @@ def greedy_graph_coloring(puzzle):
class GraphColorConfig:
"""Configuration for GraphColor puzzle generation"""
min_num_colors: int = 3
max_num_colors: int = 3
num_colors: int = 3
min_num_vertices: int = 10
max_num_vertices: int = 10
edge_probability: float = 0.4
edge_probability: float = 0.1
seed: Optional[int] = None
size: int = 500
@ -188,7 +187,7 @@ class GraphColorDataset(ProceduralDataset):
puzzle = None
solution = None
num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices)
num_colors = rng.randint(self.config.min_num_colors, self.config.max_num_colors)
num_colors = self.config.num_colors
while solution is None:
puzzle = generate_graph_coloring_puzzle(
rng=rng,
@ -197,6 +196,8 @@ class GraphColorDataset(ProceduralDataset):
num_colors=num_colors,
)
solution = greedy_graph_coloring(puzzle)
if not solution:
num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices)
edges = str(puzzle["edges"])
question = f"""Please provide a coloring for this graph such that every vertex is not connected to a vertex of the same color. The graph has these properties:
@ -253,23 +254,22 @@ class GraphColorCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="num_vertices",
levels=[10, 20, 30, 40],
levels=[10, 20, 25, 50],
default_level=0,
description="Number of vertices in the graph",
attr_type=AttributeType.APPEND,
attr_type=AttributeType.STATIC,
min_value=10,
lower_field_name="min_num_vertices",
upper_field_name="max_num_vertices",
),
RangeAttributeDefinition(
ScalarAttributeDefinition(
name="num_colors",
levels=[3, 4, 5, 6],
field_name="num_colors",
levels=[5, 4, 3],
default_level=0,
description="Number of colors in the graph",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_num_colors",
upper_field_name="max_num_colors",
attr_type=AttributeType.STATIC,
min_value=3,
),
)

View file

@ -12,8 +12,7 @@ def test_graph_color():
size=10,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=4,
max_num_colors=4,
num_colors=4,
edge_probability=0.4,
)
dataset = GraphColorDataset(config)
@ -35,9 +34,8 @@ def test_graph_color():
size=1,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=3,
max_num_colors=3,
edge_probability=0.3,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -49,11 +47,10 @@ def test_graph_color():
config = GraphColorConfig(
seed=42,
size=1,
min_num_vertices=40,
max_num_vertices=40,
min_num_colors=4,
max_num_colors=4,
edge_probability=0.2,
min_num_vertices=15,
max_num_vertices=15,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -67,8 +64,7 @@ def test_graph_color():
size=1,
min_num_vertices=50,
max_num_vertices=50,
min_num_colors=3,
max_num_colors=3,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -87,38 +83,15 @@ def test_graph_color_curriculum():
assert base_cfg.size == 150
assert base_cfg.seed == 1
assert base_cfg.min_num_vertices == base_cfg.max_num_vertices == 10
assert base_cfg.min_num_colors == base_cfg.max_num_colors == 3
assert base_cfg.num_colors == base_cfg.num_colors == 5
curriculum.increment_attr_level("num_vertices")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 20
assert cfg.min_num_vertices == 20
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 4
curriculum.increment_attr_level("num_vertices")
assert cfg.num_colors == 4
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5
curriculum.increment_attr_level("num_vertices")
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 40
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 6
curriculum.decrement_attr_level("num_vertices")
curriculum.decrement_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5
assert cfg.num_colors == 3