diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 2e8cf322..89ac5cf5 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset +from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset @@ -50,4 +51,8 @@ __all__ = [ "DiceDataset", "NumberFormatConfig", "NumberFormatDataset", + "DecimalArithmeticConfig", + "DecimalArithmeticDataset", + "DecimalChainSumConfig", + "DecimalChainSumDataset", ] diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index b5877dba..d9465f23 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -8,7 +8,7 @@ from ..factory import ProceduralDataset, register_dataset @dataclass -class DecimalArithmeticDatasetConfig: +class DecimalArithmeticConfig: """Configuration for decimal arithmetic dataset generation""" min_num_decimal_places: int = 6 @@ -140,7 +140,7 @@ def _eval_ast(node) -> Decimal: class DecimalArithmeticDataset(ProceduralDataset): """Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence.""" - def __init__(self, config: DecimalArithmeticDatasetConfig): + def __init__(self, config: DecimalArithmeticConfig): super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict[str, Any]: @@ -202,4 +202,4 @@ class DecimalArithmeticDataset(ProceduralDataset): # Register the dataset with the factory. -register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig) +register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig) diff --git a/tests/test_decimal_arithmetic.py b/tests/test_decimal_arithmetic.py index 60a9d91f..3595d97b 100644 --- a/tests/test_decimal_arithmetic.py +++ b/tests/test_decimal_arithmetic.py @@ -1,13 +1,13 @@ import pytest -from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticDataset, DecimalArithmeticDatasetConfig +from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset def test_decimal_arithmetic(): """Test basic properties and solution of generated items""" # Easy - config = DecimalArithmeticDatasetConfig( + config = DecimalArithmeticConfig( seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3 ) dataset = DecimalArithmeticDataset(config) @@ -22,7 +22,7 @@ def test_decimal_arithmetic(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 # M - config = DecimalArithmeticDatasetConfig( + config = DecimalArithmeticConfig( seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6 ) dataset = DecimalArithmeticDataset(config) @@ -36,7 +36,7 @@ def test_decimal_arithmetic(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 # H - config = DecimalArithmeticDatasetConfig( + config = DecimalArithmeticConfig( seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10 ) dataset = DecimalArithmeticDataset(config)