shortest path curriculum (#271)

This commit is contained in:
Zafir Stojanovski 2025-03-05 22:46:10 +01:00 committed by GitHub
parent 5bac641650
commit f426db90ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 2 deletions

View file

@ -2,7 +2,7 @@
import pytest
from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathDataset
from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
def test_shortest_path_config_validation():
@ -179,3 +179,28 @@ def test_shortest_path_answer():
},
}
assert dataset.score_answer(None, entry) == 0.0
def test_chain_sum_curriculum():
curriculum = ShortestPathCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ShortestPathConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10
assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10
# test incrementing attribute levels
curriculum.increment_attr_level("rows")
curriculum.increment_attr_level("cols")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25
assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25
# test decrementing attribute level for rows again
curriculum.decrement_attr_level("rows")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 10
assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25