diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 2a3f6b7a..0dccad1f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -56,7 +56,11 @@ class ChainSum: num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) - expression, result = self._generate_task(item_rng, num_terms, num_digits) + # Calculate value ranges based on number of digits + min_value = 10 ** (num_digits - 1) # e.g., 100 for 3 digits + max_value = (10 ** num_digits) - 1 # e.g., 999 for 3 digits + + expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -68,21 +72,18 @@ class ChainSum: } } - def _generate_task(self, rng: random.Random, num_terms: int, num_digits: int) -> tuple[str, int]: + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: """Generate a chain sum task Args: rng: Random number generator num_terms: Number of terms in the expression - num_digits: Number of digits for each number + min_value: Minimum value for generated numbers + max_value: Maximum value for generated numbers Returns: Tuple of (expression string, result integer) """ - # Generate numbers with at least min_digits - min_value = 10 ** (num_digits - 1) # e.g., 100 for 3 digits - max_value = (10 ** num_digits) - 1 # e.g., 999 for 3 digits - if self.config.allow_negation: # Allow both positive and negative numbers in the range constants = [rng.randint(-max_value, max_value) for _ in range(num_terms)]