added Decimal curriculum (#280)

* added decimal curricula

* added chain sum decimal curriculum

* register DecimalArithmeticCurriculum & DecimalChainSumCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
joesharratt1229 2025-03-07 23:02:57 +01:00 committed by GitHub
parent dc657b5ed4
commit e304b20e24
5 changed files with 178 additions and 16 deletions

View file

@ -7,8 +7,8 @@ from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDatase
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
from .dice import DiceConfig, DiceDataset from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset from .gcd import GCDConfig, GCDDataset
@ -57,6 +57,8 @@ __all__ = [
"NumberFormatDataset", "NumberFormatDataset",
"DecimalArithmeticConfig", "DecimalArithmeticConfig",
"DecimalArithmeticDataset", "DecimalArithmeticDataset",
"DecimalArithmeticCurriculum",
"DecimalChainSumCurriculum",
"DecimalChainSumConfig", "DecimalChainSumConfig",
"DecimalChainSumDataset", "DecimalChainSumDataset",
"BitwiseArithmeticConfig", "BitwiseArithmeticConfig",

View file

@ -4,6 +4,7 @@ from decimal import ROUND_HALF_UP, Decimal, getcontext
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -13,8 +14,9 @@ class DecimalArithmeticConfig:
min_num_decimal_places: int = 3 min_num_decimal_places: int = 3
max_num_decimal_places: int = 3 max_num_decimal_places: int = 3
precision: int = 6 min_terms: int = 2
terms: int = 6 max_terms: int = 6
precision: int = 12
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
@ -31,7 +33,7 @@ def build_grouped_expression(operands: list[str], operators: list[str], rng: Ran
inserting parentheses at random. inserting parentheses at random.
The expression is built by choosing a random split among the operands; The expression is built by choosing a random split among the operands;
the operator at that split becomes the root of the subexpression. the operator at that split becomes the "root" of the subexpression.
With 50% chance, the resulting combination is wrapped in parentheses. With 50% chance, the resulting combination is wrapped in parentheses.
""" """
if len(operands) == 1: if len(operands) == 1:
@ -74,10 +76,13 @@ def generate_arithmetic_problem(
operands: list[str] = [] operands: list[str] = []
operators: list[str] = [] operators: list[str] = []
max_ndp = 1
for i in range(terms): for i in range(terms):
# Choose a random number of decimal places for this term. # Choose a random number of decimal places for this term.
ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places) ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places)
if ndp > max_ndp:
max_ndp = ndp
max_integer_part: int = 10 # Maximum whole number before the decimal max_integer_part: int = 10 # Maximum whole number before the decimal
max_value: int = max_integer_part * (10**ndp) max_value: int = max_integer_part * (10**ndp)
raw_int: int = rng.randint(1, max_value) raw_int: int = rng.randint(1, max_value)
@ -94,7 +99,7 @@ def generate_arithmetic_problem(
expr: str = build_grouped_expression(operands, operators, rng) expr: str = build_grouped_expression(operands, operators, rng)
problem_str: str = expr + " = ?" problem_str: str = expr + " = ?"
return problem_str return problem_str, max_ndp
def evaluate_expression(expr: str) -> Decimal: def evaluate_expression(expr: str) -> Decimal:
@ -163,11 +168,13 @@ class DecimalArithmeticDataset(ProceduralDataset):
rng: Random = Random(self.seed + idx if self.seed is not None else None) rng: Random = Random(self.seed + idx if self.seed is not None else None)
getcontext().prec = self.config.precision getcontext().prec = self.config.precision
problem_str: str = generate_arithmetic_problem( terms = rng.randint(self.config.min_terms, self.config.max_terms)
problem_str, decimal_places = generate_arithmetic_problem(
rng, rng,
self.config.min_num_decimal_places, self.config.min_num_decimal_places,
self.config.max_num_decimal_places, self.config.max_num_decimal_places,
terms=self.config.terms, terms=terms,
) )
# Remove the trailing " = ?" to obtain the pure arithmetic expression. # Remove the trailing " = ?" to obtain the pure arithmetic expression.
expr: str = problem_str.replace(" = ?", "").strip() expr: str = problem_str.replace(" = ?", "").strip()
@ -178,7 +185,11 @@ class DecimalArithmeticDataset(ProceduralDataset):
+ problem_str + problem_str
) )
return {"question": problem_str, "answer": str(answer), "metadata": {}} return {
"question": problem_str,
"answer": str(answer),
"metadata": {"decimal_places": decimal_places, "num_terms": terms},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
""" """
@ -207,5 +218,34 @@ class DecimalArithmeticDataset(ProceduralDataset):
return 0.0 return 0.0
class DecimalArithmeticCurriculum(BaseCurriculum):
"""Curriculum for Decimal Arithmetic"""
def __init__(self):
super().__init__(DecimalArithmeticCurriculum.__name__, DecimalArithmeticConfig)
self._define_attributes(
RangeAttributeDefinition(
name="decimal_places",
levels=[3, 5, 8, 10],
default_level=0,
description="Number of decimal places of the numbers in problem",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_num_decimal_places",
upper_field_name="max_num_decimal_places",
),
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 6],
default_level=0,
description="Number of terms in the arithmetic expression",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_terms",
upper_field_name="max_terms",
),
)
# Register the dataset with the factory. # Register the dataset with the factory.
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig) register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)

View file

@ -3,6 +3,7 @@ from dataclasses import dataclass
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -160,4 +161,43 @@ class DecimalChainSumDataset(ProceduralDataset):
return 0.0 return 0.0
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig) class DecimalChainSumCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(DecimalChainSumCurriculum.__name__, DecimalChainSumConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
default_level=0,
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_terms",
upper_field_name="max_terms",
),
RangeAttributeDefinition(
name="num_digits",
levels=[1, 2, 4, 10],
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_digits",
upper_field_name="max_digits",
),
RangeAttributeDefinition(
name="decimal_places",
levels=[1, 2, 3, 4],
default_level=0,
description="Number of decimal places in each operand",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_decimal_places",
upper_field_name="max_decimal_places",
),
)
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)

View file

@ -1,6 +1,10 @@
import pytest import pytest
from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset from reasoning_gym.arithmetic.decimal_arithmetic import (
DecimalArithmeticConfig,
DecimalArithmeticCurriculum,
DecimalArithmeticDataset,
)
def test_decimal_arithmetic(): def test_decimal_arithmetic():
@ -8,7 +12,7 @@ def test_decimal_arithmetic():
# Easy # Easy
config = DecimalArithmeticConfig( config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3 seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, min_terms=2, max_terms=3
) )
dataset = DecimalArithmeticDataset(config) dataset = DecimalArithmeticDataset(config)
@ -23,7 +27,7 @@ def test_decimal_arithmetic():
# M # M
config = DecimalArithmeticConfig( config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6 seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, min_terms=3, max_terms=5
) )
dataset = DecimalArithmeticDataset(config) dataset = DecimalArithmeticDataset(config)
@ -37,7 +41,7 @@ def test_decimal_arithmetic():
# H # H
config = DecimalArithmeticConfig( config = DecimalArithmeticConfig(
seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10 seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, min_terms=3, max_terms=5
) )
dataset = DecimalArithmeticDataset(config) dataset = DecimalArithmeticDataset(config)
@ -48,3 +52,36 @@ def test_decimal_arithmetic():
assert "metadata" in item assert "metadata" in item
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
def test_decimal_arithmetic_curriculum():
"""Test the decimal arithmetic curriculum generation and attribute adjustment"""
curriculum = DecimalArithmeticCurriculum()
base_value = {"size": 200, "seed": 42, "precision": 6}
base_cfg: DecimalArithmeticConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 42
assert base_cfg.size == 200
assert base_cfg.precision == 6
assert base_cfg.min_num_decimal_places == 3 and base_cfg.max_num_decimal_places == 3
# Test incrementing attribute level
curriculum.increment_attr_level("decimal_places")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_num_decimal_places == 3 and increased_cfg.max_num_decimal_places == 5
# Test incrementing attribute level again
curriculum.increment_attr_level("decimal_places")
further_increased_cfg = curriculum.generate_configuration(base_value)
assert further_increased_cfg.min_num_decimal_places == 3 and further_increased_cfg.max_num_decimal_places == 8
# Test decrementing attribute level
curriculum.decrement_attr_level("decimal_places")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_num_decimal_places == 3 and decreased_cfg.max_num_decimal_places == 5
# Test decrementing attribute level to base level
curriculum.decrement_attr_level("decimal_places")
base_level_cfg = curriculum.generate_configuration(base_value)
assert base_level_cfg.min_num_decimal_places == 3 and base_level_cfg.max_num_decimal_places == 3

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
def test_decimal_chain_sum_config_validation(): def test_decimal_chain_sum_config_validation():
@ -250,3 +250,46 @@ def test_decimal_precision_scoring():
assert dataset.score_answer("", {"answer": "1.200"}) == 0.0 assert dataset.score_answer("", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0 assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0
assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0 assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0
def test_decimal_chain_sum_curriculum():
"""Test that the decimal chain sum curriculum works as expected"""
curriculum = DecimalChainSumCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: DecimalChainSumConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
assert base_cfg.min_decimal_places == 1 and base_cfg.max_decimal_places == 1
# test incrementing attribute levels for num_terms, num_digits, & decimal_places attributes
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("num_digits")
curriculum.increment_attr_level("decimal_places")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
assert increased_cfg.min_decimal_places == 1 and increased_cfg.max_decimal_places == 2
# test decrementing attribute level for num_digits and decimal_places
curriculum.decrement_attr_level("num_digits")
curriculum.decrement_attr_level("decimal_places")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3
assert partially_decreased_cfg.min_decimal_places == 1 and partially_decreased_cfg.max_decimal_places == 1
# test that trying to decrement below minimum doesn't change configuration
curriculum.decrement_attr_level("num_terms") # Already at minimum
curriculum.decrement_attr_level("num_digits") # Already at minimum
curriculum.decrement_attr_level("decimal_places") # Already at minimum
min_level_cfg = curriculum.generate_configuration(base_value)
assert min_level_cfg.min_digits == 1 and min_level_cfg.max_digits == 1
assert min_level_cfg.min_terms == 2 and min_level_cfg.max_terms == 2
assert min_level_cfg.min_decimal_places == 1 and min_level_cfg.max_decimal_places == 1