feat: Add support for integer division in BasicArithmeticDataset

This commit is contained in:
Andreas Koepf (aider) 2025-01-24 11:30:38 +01:00
parent 336fdad55c
commit 7a64273f2e

View file

@@ -13,7 +13,7 @@ class BasicArithmeticDatasetConfig:
max_terms: int = 6
min_digits: int = 1
max_digits: int = 4
operators: list[str] = ("+", "-", "*")
operators: list[str] = ("+", "-", "*", "/")
allow_parentheses: bool = True
allow_negation: bool = True
seed: Optional[int] = None
@@ -29,7 +29,7 @@ class BasicArithmeticDatasetConfig:
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
assert len(self.operators) > 0, "must provide at least one operator"
for op in self.operators:
assert op in ["+", "-", "*"], f"unsupported operator: {op}"
assert op in ["+", "-", "*", "/"], f"unsupported operator: {op}"
class BasicArithmeticDataset(ProceduralDataset):
@@ -92,10 +92,32 @@ class BasicArithmeticDataset(ProceduralDataset):
parts.append(")")
else:
for i in range(num_left):
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
parts.append(str(c))
if i + 1 < num_left:
parts.append(rng.choice(self.config.operators))
if i + 1 < num_left or "/" not in self.config.operators:
# For non-division terms or when division isn't used
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
parts.append(str(c))
if i + 1 < num_left:
op = rng.choice(self.config.operators)
parts.append(op)
else:
# Handle division case - ensure integer result
expr = "".join(parts)
try:
dividend = eval(expr) # Evaluate left part
# Find potential divisors
divisors = [d for d in range(2, min(abs(dividend), 10**num_digits))
if dividend % d == 0]
if divisors:
divisor = rng.choice(divisors)
parts.append(str(divisor))
else:
# Fallback if no divisors found
c = rng.randint(1, 10**num_digits - 1)
parts.append(str(c))
except:
# Fallback if evaluation fails
c = rng.randint(1, 10**num_digits - 1)
parts.append(str(c))
if num_right > 0:
parts.append(rng.choice(self.config.operators))
@@ -140,6 +162,18 @@ class BasicArithmeticDataset(ProceduralDataset):
result -= c
elif op == "*":
result *= c
elif op == "/":
# Find a number that divides result evenly
divisors = [d for d in range(2, min(abs(result), 10**num_digits))
if result % d == 0]
if divisors:
c = rng.choice(divisors)
result //= c
else:
# Fallback to multiplication if no clean division possible
op = "*"
c = rng.randint(1, 10**num_digits - 1)
result *= c
else:
raise RuntimeError(f"Unsupported operator: {op}")