refactor test

This commit is contained in:
Rich Jones 2025-03-10 14:22:00 +01:00
parent 4ea70e2ce4
commit cedba8ec26
2 changed files with 19 additions and 49 deletions

View file

@ -12,8 +12,7 @@ def test_graph_color():
size=10,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=4,
max_num_colors=4,
num_colors=4,
edge_probability=0.4,
)
dataset = GraphColorDataset(config)
@ -35,9 +34,8 @@ def test_graph_color():
size=1,
min_num_vertices=10,
max_num_vertices=10,
min_num_colors=3,
max_num_colors=3,
edge_probability=0.3,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -49,11 +47,10 @@ def test_graph_color():
config = GraphColorConfig(
seed=42,
size=1,
min_num_vertices=40,
max_num_vertices=40,
min_num_colors=4,
max_num_colors=4,
edge_probability=0.2,
min_num_vertices=15,
max_num_vertices=15,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -67,8 +64,7 @@ def test_graph_color():
size=1,
min_num_vertices=50,
max_num_vertices=50,
min_num_colors=3,
max_num_colors=3,
num_colors=3,
edge_probability=0.1,
)
dataset = GraphColorDataset(config)
@ -87,38 +83,15 @@ def test_graph_color_curriculum():
assert base_cfg.size == 150
assert base_cfg.seed == 1
assert base_cfg.min_num_vertices == base_cfg.max_num_vertices == 10
assert base_cfg.min_num_colors == base_cfg.max_num_colors == 3
assert base_cfg.num_colors == base_cfg.num_colors == 5
curriculum.increment_attr_level("num_vertices")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 20
assert cfg.min_num_vertices == 20
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 4
curriculum.increment_attr_level("num_vertices")
assert cfg.num_colors == 4
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5
curriculum.increment_attr_level("num_vertices")
curriculum.increment_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 40
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 6
curriculum.decrement_attr_level("num_vertices")
curriculum.decrement_attr_level("num_colors")
cfg = curriculum.generate_configuration(base_value)
assert cfg.min_num_vertices == 10
assert cfg.max_num_vertices == 30
assert cfg.min_num_colors == 3
assert cfg.max_num_colors == 5
assert cfg.num_colors == 3