diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index dc5ff890..1611af59 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -142,6 +142,41 @@ class ArithmeticDataset: expression = " ".join(expression_parts) return expression, result + +def chain_sum( + min_terms: int = 2, + max_terms: int = 6, + min_digits: int = 1, + max_digits: int = 4, + allow_negation: bool = False, + seed: Optional[int] = None, + size: int = 500, +) -> ChainSum: + """Create a ChainSum dataset with the given configuration. + + Args: + min_terms: Minimum number of terms in expressions + max_terms: Maximum number of terms in expressions + min_digits: Minimum number of digits in numbers + max_digits: Maximum number of digits in numbers + allow_negation: Whether to allow negative numbers + seed: Random seed for reproducibility + size: Virtual size of the dataset + + Returns: + ChainSum: Configured dataset instance + """ + config = ChainSumConfig( + min_terms=min_terms, + max_terms=max_terms, + min_digits=min_digits, + max_digits=max_digits, + allow_negation=allow_negation, + seed=seed, + size=size, + ) + return ChainSum(config) + def __iter__(self): """Make the dataset iterable""" self._current_idx = 0