mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add support for integer division in BasicArithmeticDataset
This commit is contained in:
parent
336fdad55c
commit
7a64273f2e
1 changed file with 40 additions and 6 deletions
|
|
@ -13,7 +13,7 @@ class BasicArithmeticDatasetConfig:
|
|||
max_terms: int = 6
|
||||
min_digits: int = 1
|
||||
max_digits: int = 4
|
||||
operators: list[str] = ("+", "-", "*")
|
||||
operators: list[str] = ("+", "-", "*", "/")
|
||||
allow_parentheses: bool = True
|
||||
allow_negation: bool = True
|
||||
seed: Optional[int] = None
|
||||
|
|
@ -29,7 +29,7 @@ class BasicArithmeticDatasetConfig:
|
|||
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
|
||||
assert len(self.operators) > 0, "must provide at least one operator"
|
||||
for op in self.operators:
|
||||
assert op in ["+", "-", "*"], f"unsupported operator: {op}"
|
||||
assert op in ["+", "-", "*", "/"], f"unsupported operator: {op}"
|
||||
|
||||
|
||||
class BasicArithmeticDataset(ProceduralDataset):
|
||||
|
|
@ -92,10 +92,32 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
parts.append(")")
|
||||
else:
|
||||
for i in range(num_left):
|
||||
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
|
||||
parts.append(str(c))
|
||||
if i + 1 < num_left:
|
||||
parts.append(rng.choice(self.config.operators))
|
||||
if i + 1 < num_left or "/" not in self.config.operators:
|
||||
# For non-division terms or when division isn't used
|
||||
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
|
||||
parts.append(str(c))
|
||||
if i + 1 < num_left:
|
||||
op = rng.choice(self.config.operators)
|
||||
parts.append(op)
|
||||
else:
|
||||
# Handle division case - ensure integer result
|
||||
expr = "".join(parts)
|
||||
try:
|
||||
dividend = eval(expr) # Evaluate left part
|
||||
# Find potential divisors
|
||||
divisors = [d for d in range(2, min(abs(dividend), 10**num_digits))
|
||||
if dividend % d == 0]
|
||||
if divisors:
|
||||
divisor = rng.choice(divisors)
|
||||
parts.append(str(divisor))
|
||||
else:
|
||||
# Fallback if no divisors found
|
||||
c = rng.randint(1, 10**num_digits - 1)
|
||||
parts.append(str(c))
|
||||
except:
|
||||
# Fallback if evaluation fails
|
||||
c = rng.randint(1, 10**num_digits - 1)
|
||||
parts.append(str(c))
|
||||
|
||||
if num_right > 0:
|
||||
parts.append(rng.choice(self.config.operators))
|
||||
|
|
@ -140,6 +162,18 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
result -= c
|
||||
elif op == "*":
|
||||
result *= c
|
||||
elif op == "/":
|
||||
# Find a number that divides result evenly
|
||||
divisors = [d for d in range(2, min(abs(result), 10**num_digits))
|
||||
if result % d == 0]
|
||||
if divisors:
|
||||
c = rng.choice(divisors)
|
||||
result //= c
|
||||
else:
|
||||
# Fallback to multiplication if no clean division possible
|
||||
op = "*"
|
||||
c = rng.randint(1, 10**num_digits - 1)
|
||||
result *= c
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported operator: {op}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue