spiral matrix curriculum (#296)

This commit is contained in:
Zafir Stojanovski 2025-03-08 20:56:08 +01:00 committed by GitHub
parent d82c73b6f8
commit e4e516a949
3 changed files with 55 additions and 8 deletions

View file

@ -2,7 +2,7 @@
import pytest
from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
def test_spiral_matrix_config_validation():
@ -96,3 +96,24 @@ def test_spiral_matrix_answer():
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = None
assert dataset.score_answer(answer, entry) == 0.0
def test_spiral_matrix_curriculum():
curriculum = SpiralMatrixCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: SpiralMatrixConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_n == 10 and base_cfg.max_n == 10
# test incrementing attribute levels
curriculum.increment_attr_level("n")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25
# test decrementing attribute levels
curriculum.decrement_attr_level("n")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_n == 10 and decreased_cfg.max_n == 10