From e304b20e24d4233045fe9eef116cfdc60e09f037 Mon Sep 17 00:00:00 2001 From: joesharratt1229 <118444587+joesharratt1229@users.noreply.github.com> Date: Fri, 7 Mar 2025 23:02:57 +0100 Subject: [PATCH] added Decimal curriculum (#280) * added decimal curricula * added chain sum decimal curriculum * register DecimalArithmeticCurriculum & DecimalChainSumCurriculum --------- Co-authored-by: Andreas Koepf --- reasoning_gym/arithmetic/__init__.py | 6 +- .../arithmetic/decimal_arithmetic.py | 56 ++++++++++++++++--- reasoning_gym/arithmetic/decimal_chain_sum.py | 42 +++++++++++++- tests/test_decimal_arithmetic.py | 45 +++++++++++++-- tests/test_decimal_chain_sum.py | 45 ++++++++++++++- 5 files changed, 178 insertions(+), 16 deletions(-) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index e895caeb..a7e6ab73 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -7,8 +7,8 @@ from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDatase from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset -from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset -from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset +from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset +from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset @@ -57,6 +57,8 @@ __all__ = [ "NumberFormatDataset", "DecimalArithmeticConfig", "DecimalArithmeticDataset", + "DecimalArithmeticCurriculum", + "DecimalChainSumCurriculum", "DecimalChainSumConfig", "DecimalChainSumDataset", "BitwiseArithmeticConfig", diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index 86c4f5ff..2ebaf26c 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -4,6 +4,7 @@ from decimal import ROUND_HALF_UP, Decimal, getcontext from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -13,8 +14,9 @@ class DecimalArithmeticConfig: min_num_decimal_places: int = 3 max_num_decimal_places: int = 3 - precision: int = 6 - terms: int = 6 + min_terms: int = 2 + max_terms: int = 6 + precision: int = 12 seed: Optional[int] = None size: int = 500 @@ -31,7 +33,7 @@ def build_grouped_expression(operands: list[str], operators: list[str], rng: Ran inserting parentheses at random. The expression is built by choosing a random split among the operands; - the operator at that split becomes the “root” of the subexpression. + the operator at that split becomes the "root" of the subexpression. With 50% chance, the resulting combination is wrapped in parentheses. """ if len(operands) == 1: @@ -74,10 +76,13 @@ def generate_arithmetic_problem( operands: list[str] = [] operators: list[str] = [] + max_ndp = 1 for i in range(terms): # Choose a random number of decimal places for this term. ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places) + if ndp > max_ndp: + max_ndp = ndp max_integer_part: int = 10 # Maximum whole number before the decimal max_value: int = max_integer_part * (10**ndp) raw_int: int = rng.randint(1, max_value) @@ -94,7 +99,7 @@ def generate_arithmetic_problem( expr: str = build_grouped_expression(operands, operators, rng) problem_str: str = expr + " = ?" - return problem_str + return problem_str, max_ndp def evaluate_expression(expr: str) -> Decimal: @@ -163,11 +168,13 @@ class DecimalArithmeticDataset(ProceduralDataset): rng: Random = Random(self.seed + idx if self.seed is not None else None) getcontext().prec = self.config.precision - problem_str: str = generate_arithmetic_problem( + terms = rng.randint(self.config.min_terms, self.config.max_terms) + + problem_str, decimal_places = generate_arithmetic_problem( rng, self.config.min_num_decimal_places, self.config.max_num_decimal_places, - terms=self.config.terms, + terms=terms, ) # Remove the trailing " = ?" to obtain the pure arithmetic expression. expr: str = problem_str.replace(" = ?", "").strip() @@ -178,7 +185,11 @@ class DecimalArithmeticDataset(ProceduralDataset): + problem_str ) - return {"question": problem_str, "answer": str(answer), "metadata": {}} + return { + "question": problem_str, + "answer": str(answer), + "metadata": {"decimal_places": decimal_places, "num_terms": terms}, + } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """ @@ -207,5 +218,34 @@ class DecimalArithmeticDataset(ProceduralDataset): return 0.0 +class DecimalArithmeticCurriculum(BaseCurriculum): + """Curriculum for Decimal Arithmetic""" + + def __init__(self): + super().__init__(DecimalArithmeticCurriculum.__name__, DecimalArithmeticConfig) + self._define_attributes( + RangeAttributeDefinition( + name="decimal_places", + levels=[3, 5, 8, 10], + default_level=0, + description="Number of decimal places of the numbers in problem", + attr_type=AttributeType.APPEND, + min_value=3, + lower_field_name="min_num_decimal_places", + upper_field_name="max_num_decimal_places", + ), + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 6], + default_level=0, + description="Number of terms in the arithmetic expression", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + ) + + # Register the dataset with the factory. -register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig) +register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum) diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py index 55f9b411..e06444dc 100644 --- a/reasoning_gym/arithmetic/decimal_chain_sum.py +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from decimal import Decimal, InvalidOperation from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -160,4 +161,43 @@ class DecimalChainSumDataset(ProceduralDataset): return 0.0 -register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig) +class DecimalChainSumCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(DecimalChainSumCurriculum.__name__, DecimalChainSumConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 4, 10], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_digits", + upper_field_name="max_digits", + ), + RangeAttributeDefinition( + name="decimal_places", + levels=[1, 2, 3, 4], + default_level=0, + description="Number of decimal places in each operand", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_decimal_places", + upper_field_name="max_decimal_places", + ), + ) + + +register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum) diff --git a/tests/test_decimal_arithmetic.py b/tests/test_decimal_arithmetic.py index 3595d97b..61d394fa 100644 --- a/tests/test_decimal_arithmetic.py +++ b/tests/test_decimal_arithmetic.py @@ -1,6 +1,10 @@ import pytest -from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset +from reasoning_gym.arithmetic.decimal_arithmetic import ( + DecimalArithmeticConfig, + DecimalArithmeticCurriculum, + DecimalArithmeticDataset, +) def test_decimal_arithmetic(): @@ -8,7 +12,7 @@ def test_decimal_arithmetic(): # Easy config = DecimalArithmeticConfig( - seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3 + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, min_terms=2, max_terms=3 ) dataset = DecimalArithmeticDataset(config) @@ -23,7 +27,7 @@ def test_decimal_arithmetic(): # M config = DecimalArithmeticConfig( - seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6 + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, min_terms=3, max_terms=5 ) dataset = DecimalArithmeticDataset(config) @@ -37,7 +41,7 @@ def test_decimal_arithmetic(): # H config = DecimalArithmeticConfig( - seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10 + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, min_terms=3, max_terms=5 ) dataset = DecimalArithmeticDataset(config) @@ -48,3 +52,36 @@ def test_decimal_arithmetic(): assert "metadata" in item assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + + +def test_decimal_arithmetic_curriculum(): + """Test the decimal arithmetic curriculum generation and attribute adjustment""" + curriculum = DecimalArithmeticCurriculum() + + base_value = {"size": 200, "seed": 42, "precision": 6} + + base_cfg: DecimalArithmeticConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 42 + assert base_cfg.size == 200 + assert base_cfg.precision == 6 + assert base_cfg.min_num_decimal_places == 3 and base_cfg.max_num_decimal_places == 3 + + # Test incrementing attribute level + curriculum.increment_attr_level("decimal_places") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_num_decimal_places == 3 and increased_cfg.max_num_decimal_places == 5 + + # Test incrementing attribute level again + curriculum.increment_attr_level("decimal_places") + further_increased_cfg = curriculum.generate_configuration(base_value) + assert further_increased_cfg.min_num_decimal_places == 3 and further_increased_cfg.max_num_decimal_places == 8 + + # Test decrementing attribute level + curriculum.decrement_attr_level("decimal_places") + decreased_cfg = curriculum.generate_configuration(base_value) + assert decreased_cfg.min_num_decimal_places == 3 and decreased_cfg.max_num_decimal_places == 5 + + # Test decrementing attribute level to base level + curriculum.decrement_attr_level("decimal_places") + base_level_cfg = curriculum.generate_configuration(base_value) + assert base_level_cfg.min_num_decimal_places == 3 and base_level_cfg.max_num_decimal_places == 3 diff --git a/tests/test_decimal_chain_sum.py b/tests/test_decimal_chain_sum.py index e488b77c..821215f1 100644 --- a/tests/test_decimal_chain_sum.py +++ b/tests/test_decimal_chain_sum.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset +from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset def test_decimal_chain_sum_config_validation(): @@ -250,3 +250,46 @@ def test_decimal_precision_scoring(): assert dataset.score_answer("", {"answer": "1.200"}) == 0.0 assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0 assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0 + + +def test_decimal_chain_sum_curriculum(): + """Test that the decimal chain sum curriculum works as expected""" + curriculum = DecimalChainSumCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: DecimalChainSumConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 + assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 + assert base_cfg.min_decimal_places == 1 and base_cfg.max_decimal_places == 1 + + # test incrementing attribute levels for num_terms, num_digits, & decimal_places attributes + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("num_digits") + curriculum.increment_attr_level("decimal_places") + + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2 + assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3 + assert increased_cfg.min_decimal_places == 1 and increased_cfg.max_decimal_places == 2 + + # test decrementing attribute level for num_digits and decimal_places + curriculum.decrement_attr_level("num_digits") + curriculum.decrement_attr_level("decimal_places") + + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1 + assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3 + assert partially_decreased_cfg.min_decimal_places == 1 and partially_decreased_cfg.max_decimal_places == 1 + + # test that trying to decrement below minimum doesn't change configuration + curriculum.decrement_attr_level("num_terms") # Already at minimum + curriculum.decrement_attr_level("num_digits") # Already at minimum + curriculum.decrement_attr_level("decimal_places") # Already at minimum + + min_level_cfg = curriculum.generate_configuration(base_value) + assert min_level_cfg.min_digits == 1 and min_level_cfg.max_digits == 1 + assert min_level_cfg.min_terms == 2 and min_level_cfg.max_terms == 2 + assert min_level_cfg.min_decimal_places == 1 and min_level_cfg.max_decimal_places == 1