"""ChainSum: generator for simple chained addition/subtraction tasks.

This module is the readable form of the patch "feat: Add ChainSum class for
generating simple arithmetic tasks" (commit c8aa98f4), which touched three
files:

* ``reasoning_gym/arithmetic/__init__.py``  - re-exported the public names
* ``reasoning_gym/arithmetic/chain_sum.py`` - ``ChainSumConfig`` / ``ChainSum``
* ``tests/test_chain_sum.py``               - pytest tests

The contents of the three files are kept below, separated by banners.
"""

from dataclasses import dataclass
from random import Random
from typing import Optional

# --- reasoning_gym/arithmetic/__init__.py -----------------------------------
# In the package layout this file contained:
#     from .chain_sum import ChainSum, ChainSumConfig
__all__ = ["ChainSum", "ChainSumConfig"]


# --- reasoning_gym/arithmetic/chain_sum.py ----------------------------------
@dataclass
class ChainSumConfig:
    """Configuration for chain sum task generation.

    Attributes:
        min_terms: Smallest number of operands in an expression.
        max_terms: Largest number of operands in an expression.
        min_digits: Smallest digit count of an operand.
        max_digits: Largest digit count of an operand.
        allow_negation: When True, operands may be negative.
        seed: Seed for the dataset's random generator. ``None`` lets
            ``random.Random`` choose a non-reproducible seed.
        size: Number of items reported by ``len()`` on the dataset.
    """

    min_terms: int = 2
    max_terms: int = 6
    min_digits: int = 1
    max_digits: int = 4
    allow_negation: bool = False
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: If a constraint is violated. The exception is
                raised explicitly instead of through ``assert`` statements so
                that validation is not stripped when Python runs with ``-O``.
        """
        checks = (
            (self.min_terms > 0, "min_terms must be positive"),
            (self.max_terms >= self.min_terms, "max_terms must be >= min_terms"),
            (self.min_digits > 0, "min_digits must be positive"),
            (self.max_digits >= self.min_digits, "max_digits must be >= min_digits"),
            (self.size >= 0, "size must be non-negative"),
        )
        for ok, message in checks:
            if not ok:
                raise AssertionError(message)


class ChainSum:
    """Generates simple arithmetic tasks using only + and - operators.

    The object behaves like a read-only sequence of ``config.size`` items:
    it supports ``len()``, integer indexing (including negative indices) and
    plain ``for`` iteration.
    """

    def __init__(self, config: ChainSumConfig):
        """Validate *config* and fix the seed that all items derive from.

        Raises:
            AssertionError: If ``config.validate()`` rejects the config.
        """
        self.config = config
        self.config.validate()
        self.rng = Random(config.seed)
        # Drawn exactly once: item seeds are then a pure function of
        # (config.seed, idx), so an item does not change with access order
        # or with how often it is requested.
        self._base_seed = self.rng.randint(0, 2**32)

    def __len__(self) -> int:
        """Return the configured number of items."""
        return self.config.size

    def __getitem__(self, idx: int) -> dict:
        """Generate a single chain sum task.

        Args:
            idx: Index of the item to generate. Negative values count from
                the end, as with a list.

        Returns:
            dict with keys:
                - question: str, the formatted arithmetic expression
                - answer: str, the ground truth result
                - metadata: dict with generation parameters

        Raises:
            IndexError: If *idx* is outside ``[-len(self), len(self))``.
                This is also what lets ``for item in dataset`` terminate.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        if idx < 0:
            idx += size

        # Per-item generator, independent of any previous __getitem__ call.
        item_rng = Random(self._base_seed + idx)

        num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
        num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)

        expression, result = self._generate_task(item_rng, num_terms, num_digits)

        return {
            "question": f"{expression} =",
            "answer": str(result),
            "metadata": {
                "num_terms": num_terms,
                "num_digits": num_digits,
                "expression": expression,
            },
        }

    def _generate_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
        """Generate a chain sum task.

        Args:
            rng: Random generator to draw operands and operators from.
            num_terms: Number of operands in the expression.
            num_digits: Exact number of decimal digits of every operand.

        Returns:
            ``(expression, result)`` where *expression* is the space-separated
            text such as ``"12 + 34 - 56"`` and *result* is its integer value.
        """
        # Smallest/largest values that have exactly num_digits digits; 0 is
        # the only value allowed to "start" with a zero, so it belongs to the
        # one-digit range.
        low = 0 if num_digits == 1 else 10 ** (num_digits - 1)
        high = 10**num_digits - 1

        constants = [rng.randint(low, high) for _ in range(num_terms)]
        if self.config.allow_negation:
            # Signs are drawn only when the option is on, so the random
            # stream of the default configuration is unaffected.
            constants = [rng.choice((1, -1)) * value for value in constants]
        operators = [rng.choice(("+", "-")) for _ in range(num_terms - 1)]

        # Build expression and compute result in one pass.
        result = constants[0]
        expression_parts = [str(constants[0])]
        for op, value in zip(operators, constants[1:]):
            expression_parts += (op, str(value))
            result += value if op == "+" else -value

        return " ".join(expression_parts), result


# --- tests/test_chain_sum.py ------------------------------------------------
def _evaluate(expression: str) -> int:
    """Recompute a generated expression without resorting to ``eval``."""
    tokens = expression.split()
    total = int(tokens[0])
    for op, value in zip(tokens[1::2], tokens[2::2]):
        total += int(value) if op == "+" else -int(value)
    return total


def test_chain_sum_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Imported here so that importing this module for ChainSum itself does
    # not require pytest to be installed.
    import pytest

    with pytest.raises(AssertionError):
        ChainSumConfig(min_terms=0).validate()

    with pytest.raises(AssertionError):
        ChainSumConfig(min_terms=3, max_terms=2).validate()


def test_chain_sum_deterministic():
    """Test that dataset generates same items with same seed"""
    config = ChainSumConfig(seed=42, size=10)
    dataset1 = ChainSum(config)
    dataset2 = ChainSum(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]

    # Items must not depend on access order or on repeated access.
    for i in reversed(range(len(dataset1))):
        assert dataset1[i] == dataset2[i]
    assert dataset1[3] == dataset1[3]


def test_chain_sum_items():
    """Test basic properties of generated items"""
    config = ChainSumConfig(
        min_terms=2,
        max_terms=4,
        min_digits=1,
        max_digits=2,
        size=100,
        seed=42,
    )
    dataset = ChainSum(config)
    allowed = set("+- ")

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify only + and - are used
        expression = item["metadata"]["expression"]
        assert all(ch in allowed or ch.isdigit() for ch in expression)

        # Verify every operand has the advertised number of digits
        num_digits = item["metadata"]["num_digits"]
        assert all(len(tok) == num_digits for tok in expression.split()[0::2])

        # Verify the answer matches the expression
        assert str(_evaluate(expression)) == item["answer"]


def test_chain_sum_bounds():
    """Test that indexing follows the sequence protocol"""
    import pytest

    dataset = ChainSum(ChainSumConfig(seed=7, size=5))

    assert len(list(dataset)) == 5
    assert dataset[-1] == dataset[4]
    with pytest.raises(IndexError):
        dataset[5]