diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py
index a9a99699..a2fa6157 100644
--- a/reasoning_gym/arithmetic/chain_sum.py
+++ b/reasoning_gym/arithmetic/chain_sum.py
@@ -64,8 +64,22 @@ class ChainSum:
         }
 
     def _generate_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
-        """Generate a chain sum task"""
-        constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)]
+        """Generate a chain sum task
+
+        Args:
+            rng: Random number generator
+            num_terms: Number of terms in the expression
+            num_digits: Exact number of digits for each number
+
+        Returns:
+            Tuple of (expression string, result integer)
+        """
+        # Draw numbers with exactly num_digits digits. Zero is a valid
+        # one-digit number, so the lower bound is 0 (not 10**0 == 1) there.
+        min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1)  # e.g., 100 for 3 digits
+        max_value = (10 ** num_digits) - 1  # e.g., 999 for 3 digits
+
+        constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
         operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)]
 
         # Build expression and compute result
diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py
index f50e2335..75094a20 100644
--- a/tests/test_chain_sum.py
+++ b/tests/test_chain_sum.py
@@ -49,3 +49,27 @@ def test_chain_sum_items():
         # Verify the answer matches the expression
         answer = eval(expression)  # Safe here as we control the expression
         assert str(answer) == item["answer"]
+
+
+def test_chain_sum_number_ranges():
+    """Test that generated numbers respect digit constraints"""
+    config = ChainSumConfig(
+        min_terms=2,
+        max_terms=2,  # Fix to 2 terms for easier testing
+        min_digits=3,  # Should generate numbers >= 100
+        max_digits=3,  # Should generate numbers <= 999
+        size=50,
+        seed=42
+    )
+    dataset = ChainSum(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        expression = item["metadata"]["expression"]
+
+        # Extract numbers from expression
+        numbers = [int(n) for n in expression.split() if n.isdigit()]
+
+        # Verify each number is in the correct range
+        for num in numbers:
+            assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"