refactor test

This commit is contained in:
Rich Jones 2025-03-10 14:22:00 +01:00
parent 4ea70e2ce4
commit cedba8ec26
2 changed files with 19 additions and 49 deletions

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -155,11 +155,10 @@ def greedy_graph_coloring(puzzle):
class GraphColorConfig:
"""Configuration for GraphColor puzzle generation"""
min_num_colors: int = 3
max_num_colors: int = 3
num_colors: int = 3
min_num_vertices: int = 10
max_num_vertices: int = 10
edge_probability: float = 0.4
edge_probability: float = 0.1
seed: Optional[int] = None
size: int = 500
@ -188,7 +187,7 @@ class GraphColorDataset(ProceduralDataset):
puzzle = None
solution = None
num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices)
num_colors = rng.randint(self.config.min_num_colors, self.config.max_num_colors)
num_colors = self.config.num_colors
while solution is None:
puzzle = generate_graph_coloring_puzzle(
rng=rng,
@ -199,7 +198,6 @@ class GraphColorDataset(ProceduralDataset):
solution = greedy_graph_coloring(puzzle)
if not solution:
num_vertices = rng.randint(self.config.min_num_vertices, self.config.max_num_vertices)
num_colors = rng.randint(self.config.min_num_colors, self.config.max_num_colors)
edges = str(puzzle["edges"])
question = f"""Please provide a coloring for this graph such that every vertex is not connected to a vertex of the same color. The graph has these properties:
@ -256,7 +254,7 @@ class GraphColorCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="num_vertices",
levels=[10, 12, 15, 18],
levels=[10, 20, 25, 50],
default_level=0,
description="Number of vertices in the graph",
attr_type=AttributeType.STATIC,
@ -264,15 +262,14 @@ class GraphColorCurriculum(BaseCurriculum):
lower_field_name="min_num_vertices",
upper_field_name="max_num_vertices",
),
RangeAttributeDefinition(
ScalarAttributeDefinition(
name="num_colors",
field_name="num_colors",
levels=[5, 4, 3],
default_level=0,
description="Number of colors in the graph",
attr_type=AttributeType.STATIC,
min_value=3,
lower_field_name="min_num_colors",
upper_field_name="max_num_colors",
),
)