diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 143ef126..5fe9d088 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -81,56 +81,53 @@ class SimpleEquationsDataset(ProceduralDataset): Returns: Tuple of (equation string, solution integer) """ - x = Symbol(variable) + max_attempts = 100 # Prevent infinite loops + + for _ in range(max_attempts): + x = Symbol(variable) - # Generate terms for left side - num_terms = rng.randint(self.config.min_terms, self.config.max_terms) - terms = [] + # Generate terms for left side + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + terms = [] - # Generate all constant terms first - for _ in range(num_terms): - value = rng.randint(self.config.min_value, self.config.max_value) - terms.append(value) + # Generate all constant terms first + for _ in range(num_terms): + value = rng.randint(self.config.min_value, self.config.max_value) + terms.append(value) - # Replace one random term with the variable term - var_pos = rng.randint(0, num_terms - 1) - if "*" in self.config.operators: - coef = rng.randint(self.config.min_value, self.config.max_value) - terms[var_pos] = coef * x - else: - terms[var_pos] = x + # Replace one random term with the variable term + var_pos = rng.randint(0, num_terms - 1) + if "*" in self.config.operators: + coef = rng.randint(self.config.min_value, self.config.max_value) + terms[var_pos] = coef * x + else: + terms[var_pos] = x - # Apply operators between terms - expr = terms[0] - for i in range(1, num_terms): - op = rng.choice(self.config.operators) - if op == "+": - expr = expr + terms[i] - elif op == "-": - expr = expr - terms[i] - else: # '*' - expr = expr * terms[i] + # Apply operators between terms + expr = terms[0] + for i in range(1, num_terms): + op = rng.choice(self.config.operators) + if op == "+": + expr = expr + terms[i] + elif op == "-": + expr = expr - terms[i] + else: # '*' + expr = expr * terms[i] - left_side = expr + left_side = expr - # Generate right side - right_side = rng.randint(self.config.min_value, self.config.max_value) + # Generate right side + right_side = rng.randint(self.config.min_value, self.config.max_value) - # Create equation - equation = Eq(left_side, right_side) - solutions = solve(equation, x) + # Create equation + equation = Eq(left_side, right_side) + solutions = solve(equation, x) - # Check if we found any solutions - if not solutions: - return self._generate_equation(rng, variable) + # Check if we found any solutions and it's an integer + if solutions and isinstance(solutions[0], sympy.Integer): + return f"{left_side} = {right_side}", int(solutions[0]) - solution = solutions[0] - - # Only return if solution is a positive integer - if not (isinstance(solution, sympy.Integer)): - return self._generate_equation(rng, variable) - - return f"{left_side} = {right_side}", int(solution) + raise ValueError(f"Could not generate valid equation after {max_attempts} attempts") def simple_equations_dataset(