diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 3a25077f..fe777e2b 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -14,6 +14,7 @@ class ProductsConfig: max_terms: int = 2 min_digits: int = 1 max_digits: int = 5 + allow_negation: bool = False seed: Optional[int] = None size: int = 500 @@ -51,7 +52,10 @@ class ProductsDataset(ProceduralDataset): num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits - min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit + if self.config.allow_negation: + min_value = -1 * 10 ** (num_digits) + 1 + else: + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits expression, result = self._generate_task(rng, num_terms, min_value, max_value) diff --git a/tests/test_products.py b/tests/test_products.py index d794e28a..469ae5fd 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -77,6 +77,37 @@ def test_products_number_ranges(): assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" +def test_products_number_ranges_with_negation(): + """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers with negation + config = ProductsConfig( + min_terms=2, + max_terms=2, # Fix to 2 terms for easier testing + min_digits=3, # Should generate numbers >= -999 + max_digits=3, # Should generate numbers <= 999 + allow_negation=True, + size=50, + seed=42, + ) + dataset = ProductsDataset(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert -999 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers with negation + config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) + dataset = ProductsDataset(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert -9 <= num <= 9, f"Number {num} outside valid range for 1 digit" + + def test_products_iteration(): """Test that iteration respects dataset size""" config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing