manipulate matrix curriculum (#293)

This commit is contained in:
Zafir Stojanovski 2025-03-08 01:57:37 +01:00 committed by GitHub
parent e69ed78c26
commit 8d4e9030c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 93 additions and 10 deletions

View file

@ -2,7 +2,11 @@
import pytest
from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
from reasoning_gym.algorithmic.manipulate_matrix import (
ManipulateMatrixConfig,
ManipulateMatrixCurriculum,
ManipulateMatrixDataset,
)
def test_manipulate_matrix_config_validation():
@ -219,3 +223,32 @@ def test_manipulate_matrix_score_answer():
# answer is none
answer = None
assert dataset.score_answer(answer, entry) == 0.0
def test_manipulate_matrix_curriculum():
curriculum = ManipulateMatrixCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ManipulateMatrixConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10
assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10
assert base_cfg.min_transforms == 5 and base_cfg.max_transforms == 5
# test incrementing attribute levels
curriculum.increment_attr_level("rows")
curriculum.increment_attr_level("cols")
curriculum.increment_attr_level("num_transforms")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25
assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25
assert increased_cfg.min_transforms == 5 and increased_cfg.max_transforms == 10
# test decrementing attribute level for rows again
curriculum.decrement_attr_level("rows")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 10
assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25
assert increased_cfg.min_transforms == 5 and increased_cfg.max_transforms == 10