diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 0452392a..f0b54398 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -32,6 +32,33 @@ class BasicArithmeticDatasetConfig: assert op in ["+", "-", "*", "/"], f"unsupported operator: {op}" +def find_common_divisors(a: int, b: int) -> list[int]: + # Helper function to find GCD using Euclidean algorithm + def gcd(x, y): + while y: + x, y = y, x % y + return x + + # Get the GCD of the two numbers + gcd_value = gcd(abs(a), abs(b)) + # Find all divisors of the GCD + divisors = [] + i = 1 + # We only need to check up to sqrt(gcd_value) + while i * i <= gcd_value: + if gcd_value % i == 0: + divisors.append(i) + # Don't add the same number twice for perfect squares + if i * i != gcd_value: + divisors.append(gcd_value // i) + i += 1 + return divisors + + +def eval_floordiv(exp: str) -> int: + return eval(exp.replace("/", "//")) + + class BasicArithmeticDataset(ProceduralDataset): """Dataset that generates basic arithmetic tasks with configurable complexity""" @@ -77,53 +104,71 @@ class BasicArithmeticDataset(ProceduralDataset): def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]: """Generate a complex arithmetic task with possible parentheses""" - parts = [] - def add_terms(remaining: int): + def add_terms(remaining: int) -> list[str]: + # split terms randomly into left and right num_left = rng.randint(1, remaining) num_right = remaining - num_left + left_parts = [] if num_left > 1 and rng.random() > 0.5 and self.config.allow_parentheses: if rng.random() > 0.5 and self.config.allow_negation: - parts.append("-(") + left_parts.append("-(") else: - parts.append("(") - add_terms(num_left) - parts.append(")") + left_parts.append("(") + left_parts.extend(add_terms(num_left)) + left_parts.append(")") else: for i in range(num_left): - if i + 1 < num_left or "/" not in self.config.operators: - # For non-division terms or when division isn't used - c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) - parts.append(str(c)) - if i + 1 < num_left: - op = rng.choice(self.config.operators) - parts.append(op) + c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) + left_parts.append(str(c)) + if i + 1 < num_left: + left_parts.append(rng.choice([o for o in self.config.operators if o != "/"])) + + if num_right == 0: + return left_parts + + op = rng.choice(self.config.operators) + if op != "/": + left_parts.append(op) + left_parts.extend(add_terms(num_right)) + else: + # left part has parantheses or no division + dividend = eval_floordiv("".join(left_parts) if left_parts[-1] == ")" else left_parts[-1]) + left_parts.append(op) + + if num_right > 1: + right_parts = add_terms(num_right - 1) + if right_parts[-1] == ")": + right_value = eval_floordiv("".join(right_parts)) + + if right_value == 0: + correction = 1 + else: + target = rng.choice(find_common_divisors(dividend, right_value)) + correction = target - right_value + + right_parts.pop() + right_parts.append("+") + right_parts.append(str(correction)) + right_parts.append(")") + else: - # Handle division case - ensure integer result - expr = "".join(parts) - try: - dividend = eval(expr) # Evaluate left part - # Find potential divisors - divisors = [d for d in range(2, min(abs(dividend), 10**num_digits)) - if dividend % d == 0] - if divisors: - divisor = rng.choice(divisors) - parts.append(str(divisor)) - else: - # Fallback if no divisors found - c = rng.randint(1, 10**num_digits - 1) - parts.append(str(c)) - except: - # Fallback if evaluation fails - c = rng.randint(1, 10**num_digits - 1) - parts.append(str(c)) + divisor = rng.choice(find_common_divisors(dividend, 0)) + left_parts.append(str(divisor)) + left_parts.append("+") - if num_right > 0: - parts.append(rng.choice(self.config.operators)) - add_terms(num_right) + left_parts.extend(right_parts) + else: + if dividend != 0: + divisor = rng.choice(find_common_divisors(dividend, 0)) + else: + divisor = rng.randint(1, 10**num_digits - 1) + left_parts.append(str(divisor)) - add_terms(num_terms) + return left_parts + + parts = add_terms(num_terms) # Add whitespace according to config if self.config.whitespace == "no_space": @@ -137,7 +182,7 @@ class BasicArithmeticDataset(ProceduralDataset): space_parts.append(" ") space_parts.append(p) expression = "".join(space_parts).strip() - result = eval(expression) # Note: eval is safe here as we control the input + result = eval_floordiv(expression) # Note: eval is safe here as we control the input return expression, result @@ -164,8 +209,7 @@ class BasicArithmeticDataset(ProceduralDataset): result *= c elif op == "/": # Find a number that divides result evenly - divisors = [d for d in range(2, min(abs(result), 10**num_digits)) - if result % d == 0] + divisors = [d for d in range(2, min(abs(result), 10**num_digits)) if result % d == 0] if divisors: c = rng.choice(divisors) result //= c @@ -194,7 +238,7 @@ def basic_arithmetic_dataset( max_terms: int = 6, min_digits: int = 1, max_digits: int = 4, - operators: list[str] = ("+", "-", "*"), + operators: list[str] = ("+", "-", "*", "/"), allow_parentheses: bool = True, allow_negation: bool = True, seed: Optional[int] = None, diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py index 8472b5fa..3d3d08b5 100644 --- a/tests/test_arithmetic.py +++ b/tests/test_arithmetic.py @@ -2,7 +2,11 @@ from random import Random import pytest -from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig +from reasoning_gym.arithmetic.basic_arithmetic import ( + BasicArithmeticDataset, + BasicArithmeticDatasetConfig, + eval_floordiv, +) def test_arithmetic_dataset_config_validation(): @@ -44,7 +48,7 @@ def test_arithmetic_dataset_items(): # Verify the answer matches the expression expression = item["metadata"]["expression"] - answer = eval(expression) # Safe here as we control the expression + answer = eval_floordiv(expression) # Safe here as we control the expression assert str(answer) == item["answer"]