added Decimal curriculum (#280)

* added decimal curricula

* added chain sum decimal curriculum

* register DecimalArithmeticCurriculum & DecimalChainSumCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
joesharratt1229 2025-03-07 23:02:57 +01:00 committed by GitHub
parent dc657b5ed4
commit e304b20e24
5 changed files with 178 additions and 16 deletions

View file

@ -1,6 +1,10 @@
import pytest
from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset
from reasoning_gym.arithmetic.decimal_arithmetic import (
DecimalArithmeticConfig,
DecimalArithmeticCurriculum,
DecimalArithmeticDataset,
)
def test_decimal_arithmetic():
@ -8,7 +12,7 @@ def test_decimal_arithmetic():
# Easy
config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, min_terms=2, max_terms=3
)
dataset = DecimalArithmeticDataset(config)
@ -23,7 +27,7 @@ def test_decimal_arithmetic():
# M
config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, min_terms=3, max_terms=5
)
dataset = DecimalArithmeticDataset(config)
@ -37,7 +41,7 @@ def test_decimal_arithmetic():
# H
config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, min_terms=3, max_terms=5
)
dataset = DecimalArithmeticDataset(config)
@ -48,3 +52,36 @@ def test_decimal_arithmetic():
assert "metadata" in item
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
def test_decimal_arithmetic_curriculum():
"""Test the decimal arithmetic curriculum generation and attribute adjustment"""
curriculum = DecimalArithmeticCurriculum()
base_value = {"size": 200, "seed": 42, "precision": 6}
base_cfg: DecimalArithmeticConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 42
assert base_cfg.size == 200
assert base_cfg.precision == 6
assert base_cfg.min_num_decimal_places == 3 and base_cfg.max_num_decimal_places == 3
# Test incrementing attribute level
curriculum.increment_attr_level("decimal_places")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_num_decimal_places == 3 and increased_cfg.max_num_decimal_places == 5
# Test incrementing attribute level again
curriculum.increment_attr_level("decimal_places")
further_increased_cfg = curriculum.generate_configuration(base_value)
assert further_increased_cfg.min_num_decimal_places == 3 and further_increased_cfg.max_num_decimal_places == 8
# Test decrementing attribute level
curriculum.decrement_attr_level("decimal_places")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_num_decimal_places == 3 and decreased_cfg.max_num_decimal_places == 5
# Test decrementing attribute level to base level
curriculum.decrement_attr_level("decimal_places")
base_level_cfg = curriculum.generate_configuration(base_value)
assert base_level_cfg.min_num_decimal_places == 3 and base_level_cfg.max_num_decimal_places == 3