refactor: Improve equation generation with max attempts and error handling

This commit is contained in:
Andreas Koepf (aider) 2025-01-24 19:09:10 +01:00
parent 572a9c6db2
commit 111f172c3f

View file

@ -81,56 +81,53 @@ class SimpleEquationsDataset(ProceduralDataset):
Returns:
Tuple of (equation string, solution integer)
"""
x = Symbol(variable)
max_attempts = 100 # Prevent infinite loops
for _ in range(max_attempts):
x = Symbol(variable)
# Generate terms for left side
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
terms = []
# Generate terms for left side
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
terms = []
# Generate all constant terms first
for _ in range(num_terms):
value = rng.randint(self.config.min_value, self.config.max_value)
terms.append(value)
# Generate all constant terms first
for _ in range(num_terms):
value = rng.randint(self.config.min_value, self.config.max_value)
terms.append(value)
# Replace one random term with the variable term
var_pos = rng.randint(0, num_terms - 1)
if "*" in self.config.operators:
coef = rng.randint(self.config.min_value, self.config.max_value)
terms[var_pos] = coef * x
else:
terms[var_pos] = x
# Replace one random term with the variable term
var_pos = rng.randint(0, num_terms - 1)
if "*" in self.config.operators:
coef = rng.randint(self.config.min_value, self.config.max_value)
terms[var_pos] = coef * x
else:
terms[var_pos] = x
# Apply operators between terms
expr = terms[0]
for i in range(1, num_terms):
op = rng.choice(self.config.operators)
if op == "+":
expr = expr + terms[i]
elif op == "-":
expr = expr - terms[i]
else: # '*'
expr = expr * terms[i]
# Apply operators between terms
expr = terms[0]
for i in range(1, num_terms):
op = rng.choice(self.config.operators)
if op == "+":
expr = expr + terms[i]
elif op == "-":
expr = expr - terms[i]
else: # '*'
expr = expr * terms[i]
left_side = expr
left_side = expr
# Generate right side
right_side = rng.randint(self.config.min_value, self.config.max_value)
# Generate right side
right_side = rng.randint(self.config.min_value, self.config.max_value)
# Create equation
equation = Eq(left_side, right_side)
solutions = solve(equation, x)
# Create equation
equation = Eq(left_side, right_side)
solutions = solve(equation, x)
# Check if we found any solutions
if not solutions:
return self._generate_equation(rng, variable)
# Check if we found any solutions and it's an integer
if solutions and isinstance(solutions[0], sympy.Integer):
return f"{left_side} = {right_side}", int(solutions[0])
solution = solutions[0]
# Only return if solution is a positive integer
if not (isinstance(solution, sympy.Integer)):
return self._generate_equation(rng, variable)
return f"{left_side} = {right_side}", int(solution)
raise ValueError(f"Could not generate valid equation after {max_attempts} attempts")
def simple_equations_dataset(