diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 0bc64e37..2bd661df 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -48,15 +48,16 @@ class SimpleEquationsDataset(ProceduralDataset): """ rng = random.Random(self.seed + idx) - # Generate equation and solution - equation, solution = self._generate_equation(rng) + # Get variable and generate equation + variable = self._get_variable(rng) + equation, solution = self._generate_equation(rng, variable) return { "question": equation, "answer": str(solution), "metadata": { "equation": equation, - "variable": self._get_variable(rng), + "variable": variable, } } @@ -64,14 +65,17 @@ class SimpleEquationsDataset(ProceduralDataset): """Get a random lowercase variable name""" return rng.choice(string.ascii_lowercase) - def _generate_equation(self, rng: random.Random) -> Tuple[str, int]: + def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]: """Generate an equation and its solution + Args: + rng: Random number generator + variable: Variable symbol to use in equation + Returns: Tuple of (equation string, solution integer) """ - var = self._get_variable(rng) - x = Symbol(var) + x = Symbol(variable) # Generate left side num_terms = rng.randint(self.config.min_terms, self.config.max_terms)