pool matrix curriculum (#298)

This commit is contained in:
Zafir Stojanovski 2025-03-08 20:57:22 +01:00 committed by GitHub
parent 5963cbd59e
commit 194f08cad2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 86 additions and 14 deletions

View file

@ -3,7 +3,7 @@
import numpy as np
import pytest
from reasoning_gym.algorithmic.pool_matrix import PoolMatrixConfig, PoolMatrixDataset
from reasoning_gym.algorithmic.pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDataset
def test_pool_matrix_config_validation():
@ -161,3 +161,32 @@ def test_pool_matrix_int_answer():
matrix = matrix.reshape(1, 1)
int_answer = "\n".join(" ".join(str(x) for x in row) for row in matrix)
assert dataset.score_answer(answer=int_answer, entry=entry) == 1.0
def test_pool_matrix_curriculum():
curriculum = PoolMatrixCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: PoolMatrixConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10
assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10
assert base_cfg.min_pool_size == 3 and base_cfg.max_pool_size == 3
# test incrementing attribute levels
curriculum.increment_attr_level("rows")
curriculum.increment_attr_level("cols")
curriculum.increment_attr_level("pool_size")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25
assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25
assert increased_cfg.min_pool_size == 3 and increased_cfg.max_pool_size == 5
# test decrementing attribute level for pool_size again
curriculum.decrement_attr_level("pool_size")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 25
assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25
assert partially_decreased_cfg.min_pool_size == 3 and partially_decreased_cfg.max_pool_size == 3