mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
refactor: Improve equation generation with max attempts and error handling
This commit is contained in:
parent
572a9c6db2
commit
111f172c3f
1 changed files with 38 additions and 41 deletions
|
|
@ -81,56 +81,53 @@ class SimpleEquationsDataset(ProceduralDataset):
|
|||
Returns:
|
||||
Tuple of (equation string, solution integer)
|
||||
"""
|
||||
x = Symbol(variable)
|
||||
max_attempts = 100 # Prevent infinite loops
|
||||
|
||||
for _ in range(max_attempts):
|
||||
x = Symbol(variable)
|
||||
|
||||
# Generate terms for left side
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
terms = []
|
||||
# Generate terms for left side
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
terms = []
|
||||
|
||||
# Generate all constant terms first
|
||||
for _ in range(num_terms):
|
||||
value = rng.randint(self.config.min_value, self.config.max_value)
|
||||
terms.append(value)
|
||||
# Generate all constant terms first
|
||||
for _ in range(num_terms):
|
||||
value = rng.randint(self.config.min_value, self.config.max_value)
|
||||
terms.append(value)
|
||||
|
||||
# Replace one random term with the variable term
|
||||
var_pos = rng.randint(0, num_terms - 1)
|
||||
if "*" in self.config.operators:
|
||||
coef = rng.randint(self.config.min_value, self.config.max_value)
|
||||
terms[var_pos] = coef * x
|
||||
else:
|
||||
terms[var_pos] = x
|
||||
# Replace one random term with the variable term
|
||||
var_pos = rng.randint(0, num_terms - 1)
|
||||
if "*" in self.config.operators:
|
||||
coef = rng.randint(self.config.min_value, self.config.max_value)
|
||||
terms[var_pos] = coef * x
|
||||
else:
|
||||
terms[var_pos] = x
|
||||
|
||||
# Apply operators between terms
|
||||
expr = terms[0]
|
||||
for i in range(1, num_terms):
|
||||
op = rng.choice(self.config.operators)
|
||||
if op == "+":
|
||||
expr = expr + terms[i]
|
||||
elif op == "-":
|
||||
expr = expr - terms[i]
|
||||
else: # '*'
|
||||
expr = expr * terms[i]
|
||||
# Apply operators between terms
|
||||
expr = terms[0]
|
||||
for i in range(1, num_terms):
|
||||
op = rng.choice(self.config.operators)
|
||||
if op == "+":
|
||||
expr = expr + terms[i]
|
||||
elif op == "-":
|
||||
expr = expr - terms[i]
|
||||
else: # '*'
|
||||
expr = expr * terms[i]
|
||||
|
||||
left_side = expr
|
||||
left_side = expr
|
||||
|
||||
# Generate right side
|
||||
right_side = rng.randint(self.config.min_value, self.config.max_value)
|
||||
# Generate right side
|
||||
right_side = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
# Create equation
|
||||
equation = Eq(left_side, right_side)
|
||||
solutions = solve(equation, x)
|
||||
# Create equation
|
||||
equation = Eq(left_side, right_side)
|
||||
solutions = solve(equation, x)
|
||||
|
||||
# Check if we found any solutions
|
||||
if not solutions:
|
||||
return self._generate_equation(rng, variable)
|
||||
# Check if we found any solutions and it's an integer
|
||||
if solutions and isinstance(solutions[0], sympy.Integer):
|
||||
return f"{left_side} = {right_side}", int(solutions[0])
|
||||
|
||||
solution = solutions[0]
|
||||
|
||||
# Only return if solution is a positive integer
|
||||
if not (isinstance(solution, sympy.Integer)):
|
||||
return self._generate_equation(rng, variable)
|
||||
|
||||
return f"{left_side} = {right_side}", int(solution)
|
||||
raise ValueError(f"Could not generate valid equation after {max_attempts} attempts")
|
||||
|
||||
|
||||
def simple_equations_dataset(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue