added Decimal curriculum (#280)

* added decimal curricula

* added chain sum decimal curriculum

* register DecimalArithmeticCurriculum & DecimalChainSumCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
joesharratt1229 2025-03-07 23:02:57 +01:00 committed by GitHub
parent dc657b5ed4
commit e304b20e24
5 changed files with 178 additions and 16 deletions

View file

@ -1,6 +1,6 @@
import pytest
from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset
from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
def test_decimal_chain_sum_config_validation():
@ -250,3 +250,46 @@ def test_decimal_precision_scoring():
assert dataset.score_answer("", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0
def test_decimal_chain_sum_curriculum():
"""Test that the decimal chain sum curriculum works as expected"""
curriculum = DecimalChainSumCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: DecimalChainSumConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
assert base_cfg.min_decimal_places == 1 and base_cfg.max_decimal_places == 1
# test incrementing attribute levels for num_terms, num_digits, & decimal_places attributes
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("num_digits")
curriculum.increment_attr_level("decimal_places")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
assert increased_cfg.min_decimal_places == 1 and increased_cfg.max_decimal_places == 2
# test decrementing attribute level for num_digits and decimal_places
curriculum.decrement_attr_level("num_digits")
curriculum.decrement_attr_level("decimal_places")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3
assert partially_decreased_cfg.min_decimal_places == 1 and partially_decreased_cfg.max_decimal_places == 1
# test that trying to decrement below minimum doesn't change configuration
curriculum.decrement_attr_level("num_terms") # Already at minimum
curriculum.decrement_attr_level("num_digits") # Already at minimum
curriculum.decrement_attr_level("decimal_places") # Already at minimum
min_level_cfg = curriculum.generate_configuration(base_value)
assert min_level_cfg.min_digits == 1 and min_level_cfg.max_digits == 1
assert min_level_cfg.min_terms == 2 and min_level_cfg.max_terms == 2
assert min_level_cfg.min_decimal_places == 1 and min_level_cfg.max_decimal_places == 1