gcd curriculum (#331)

This commit is contained in:
vncntt 2025-03-11 00:25:24 -07:00 committed by GitHub
parent 126eecc798
commit c3c6cc8051
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 3 deletions

View file

@ -3,7 +3,7 @@ from math import gcd
import pytest
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
from reasoning_gym.arithmetic import GCDConfig, GCDCurriculum, GCDDataset
def test_gcd_config_validation():
@ -115,3 +115,39 @@ def test_gcd_special_cases():
# With enough samples, we should see both coprime and non-coprime numbers
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"
def test_gcd_curriculum():
"""Test that curriculum generates correct items"""
curriculum = GCDCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: GCDConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_numbers == 2 and base_cfg.max_numbers == 2
assert base_cfg.min_value == 100 and base_cfg.max_value == 100
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 3
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 1000
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 4
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 10000
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 5
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 100000
curriculum.decrement_attr_level("num_terms")
curriculum.decrement_attr_level("max_value")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_numbers == 2 and decreased_cfg.max_numbers == 4
assert decreased_cfg.min_value == 100 and decreased_cfg.max_value == 10000