mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
added Decimal curriculum (#280)
* added decimal curricula
* added chain sum decimal curriculum
* register DecimalArithmeticCurriculum & DecimalChainSumCurriculum

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
parent
dc657b5ed4
commit
e304b20e24
5 changed files with 178 additions and 16 deletions
|
|
@ -1,6 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset
|
||||
from reasoning_gym.arithmetic.decimal_arithmetic import (
|
||||
DecimalArithmeticConfig,
|
||||
DecimalArithmeticCurriculum,
|
||||
DecimalArithmeticDataset,
|
||||
)
|
||||
|
||||
|
||||
def test_decimal_arithmetic():
|
||||
|
|
@ -8,7 +12,7 @@ def test_decimal_arithmetic():
|
|||
|
||||
# Easy
|
||||
config = DecimalArithmeticConfig(
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, min_terms=2, max_terms=3
|
||||
)
|
||||
dataset = DecimalArithmeticDataset(config)
|
||||
|
||||
|
|
@ -23,7 +27,7 @@ def test_decimal_arithmetic():
|
|||
|
||||
# M
|
||||
config = DecimalArithmeticConfig(
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, min_terms=3, max_terms=5
|
||||
)
|
||||
dataset = DecimalArithmeticDataset(config)
|
||||
|
||||
|
|
@ -37,7 +41,7 @@ def test_decimal_arithmetic():
|
|||
|
||||
# H
|
||||
config = DecimalArithmeticConfig(
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10
|
||||
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, min_terms=3, max_terms=5
|
||||
)
|
||||
dataset = DecimalArithmeticDataset(config)
|
||||
|
||||
|
|
@ -48,3 +52,36 @@ def test_decimal_arithmetic():
|
|||
assert "metadata" in item
|
||||
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
|
||||
|
||||
def test_decimal_arithmetic_curriculum():
    """Check DecimalArithmeticCurriculum config generation and level stepping.

    A configuration is generated from fixed overrides, then the
    ``decimal_places`` attribute level is raised twice and lowered twice,
    asserting the (min, max) decimal-place range after every step.
    """
    curriculum = DecimalArithmeticCurriculum()
    overrides = {"size": 200, "seed": 42, "precision": 6}

    def decimal_place_range():
        # Regenerate the configuration so it reflects the current level.
        cfg = curriculum.generate_configuration(overrides)
        return cfg.min_num_decimal_places, cfg.max_num_decimal_places

    # Base level: overrides are passed through and the range is 3..3.
    initial = curriculum.generate_configuration(overrides)
    assert initial.seed == 42
    assert initial.size == 200
    assert initial.precision == 6
    assert (initial.min_num_decimal_places, initial.max_num_decimal_places) == (3, 3)

    # Each row: (level adjustment to apply, expected (min, max) afterwards).
    steps = (
        (curriculum.increment_attr_level, (3, 5)),  # one level up
        (curriculum.increment_attr_level, (3, 8)),  # two levels up
        (curriculum.decrement_attr_level, (3, 5)),  # back down one level
        (curriculum.decrement_attr_level, (3, 3)),  # back at the base level
    )
    for adjust_level, expected_range in steps:
        adjust_level("decimal_places")
        assert decimal_place_range() == expected_range
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset
|
||||
from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
|
||||
|
||||
|
||||
def test_decimal_chain_sum_config_validation():
|
||||
|
|
@ -250,3 +250,46 @@ def test_decimal_precision_scoring():
|
|||
assert dataset.score_answer("", {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0
|
||||
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0
|
||||
|
||||
|
||||
def test_decimal_chain_sum_curriculum():
    """Check DecimalChainSumCurriculum config generation and level stepping.

    Covers the base configuration, raising all three attributes by one
    level, lowering two of them again, and finally issuing a decrement on
    every attribute (two of which are already at their base level).
    """
    curriculum = DecimalChainSumCurriculum()
    overrides = {"size": 150, "seed": 1}

    def ranges(cfg):
        # Collapse a config into ((digits), (terms), (decimal places)) pairs.
        return (
            (cfg.min_digits, cfg.max_digits),
            (cfg.min_terms, cfg.max_terms),
            (cfg.min_decimal_places, cfg.max_decimal_places),
        )

    # Base level: overrides are passed through, every range is degenerate.
    initial = curriculum.generate_configuration(overrides)
    assert initial.seed == 1
    assert initial.size == 150
    assert ranges(initial) == ((1, 1), (2, 2), (1, 1))

    # Raise num_terms, num_digits and decimal_places by one level each.
    for attr in ("num_terms", "num_digits", "decimal_places"):
        curriculum.increment_attr_level(attr)
    raised = curriculum.generate_configuration(overrides)
    assert ranges(raised) == ((1, 2), (2, 3), (1, 2))

    # Lower num_digits and decimal_places again; num_terms stays raised.
    for attr in ("num_digits", "decimal_places"):
        curriculum.decrement_attr_level(attr)
    partly_lowered = curriculum.generate_configuration(overrides)
    assert ranges(partly_lowered) == ((1, 1), (2, 3), (1, 1))

    # Decrement everything once more. num_terms was still one level up and
    # returns to its base range; num_digits and decimal_places were already
    # at the base level and are expected to remain there.
    for attr in ("num_terms", "num_digits", "decimal_places"):
        curriculum.decrement_attr_level(attr)
    lowest = curriculum.generate_configuration(overrides)
    assert ranges(lowest) == ((1, 1), (2, 2), (1, 1))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue