diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index a1c21e70..0452392a 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -13,7 +13,7 @@ class BasicArithmeticDatasetConfig: max_terms: int = 6 min_digits: int = 1 max_digits: int = 4 - operators: list[str] = ("+", "-", "*") + operators: list[str] = ("+", "-", "*", "/") allow_parentheses: bool = True allow_negation: bool = True seed: Optional[int] = None @@ -29,7 +29,7 @@ class BasicArithmeticDatasetConfig: assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" assert len(self.operators) > 0, "must provide at least one operator" for op in self.operators: - assert op in ["+", "-", "*"], f"unsupported operator: {op}" + assert op in ["+", "-", "*", "/"], f"unsupported operator: {op}" class BasicArithmeticDataset(ProceduralDataset): @@ -92,10 +92,32 @@ class BasicArithmeticDataset(ProceduralDataset): parts.append(")") else: for i in range(num_left): - c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) - parts.append(str(c)) - if i + 1 < num_left: - parts.append(rng.choice(self.config.operators)) + if i + 1 < num_left or "/" not in self.config.operators: + # For non-division terms or when division isn't used + c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1) + parts.append(str(c)) + if i + 1 < num_left: + op = rng.choice(self.config.operators) + parts.append(op) + else: + # Handle division case - ensure integer result + expr = "".join(parts) + try: + dividend = eval(expr) # Evaluate left part + # Find potential divisors + divisors = [d for d in range(2, min(abs(dividend), 10**num_digits)) + if dividend % d == 0] + if divisors: + divisor = rng.choice(divisors) + parts.append(str(divisor)) + else: + # Fallback if no divisors found + c = rng.randint(1, 10**num_digits - 1) + parts.append(str(c)) + except: + # Fallback if evaluation fails + c = rng.randint(1, 10**num_digits - 1) + parts.append(str(c)) if num_right > 0: parts.append(rng.choice(self.config.operators)) @@ -140,6 +162,18 @@ class BasicArithmeticDataset(ProceduralDataset): result -= c elif op == "*": result *= c + elif op == "/": + # Find a number that divides result evenly + divisors = [d for d in range(2, min(abs(result), 10**num_digits)) + if result % d == 0] + if divisors: + c = rng.choice(divisors) + result //= c + else: + # Fallback to multiplication if no clean division possible + op = "*" + c = rng.randint(1, 10**num_digits - 1) + result *= c else: raise RuntimeError(f"Unsupported operator: {op}")