diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 0dccad1f..355bb57b 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -57,7 +57,7 @@ class ChainSum: num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits - min_value = 10 ** (num_digits - 1) # e.g., 100 for 3 digits + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10 ** num_digits) - 1 # e.g., 999 for 3 digits expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index ff6b0484..b5b0ea98 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -53,6 +53,7 @@ def test_chain_sum_items(): def test_chain_sum_number_ranges(): """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers config = ChainSumConfig( min_terms=2, max_terms=2, # Fix to 2 terms for easier testing @@ -63,6 +64,27 @@ def test_chain_sum_number_ranges(): ) dataset = ChainSum(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + if config.allow_negation: + assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits" + else: + assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + config = ChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + size=50, + seed=42 + ) + dataset = ChainSum(config) + for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -76,6 +98,18 @@ def test_chain_sum_number_ranges(): assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits" else: assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + dataset = ChainSum(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + if config.allow_negation: + assert -9 <= num <= 9, f"Number {num} outside valid range for 1 digit" + else: + assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" def test_chain_sum_negation(): """Test that allow_negation controls number ranges"""