diff --git a/README.md b/README.md index 8b5ef4f6..1f2d9a79 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,9 @@ The goal is to generate virtually infinite data with adjustable complexity. ### Task Overview +#### Algebra Tasks +- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3*x + 2 = 14") + #### Arithmetic Tasks - `BasicArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, *) - `ChainSum`: Generate addition/subtraction chains with configurable length and digit counts diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index e894ba12..35065d18 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,7 +2,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import algorithmic, arithmetic, cognition, data, games, logic +from . import algorithmic, algebra, arithmetic, cognition, data, games, logic -__version__ = "0.1.0" -__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"] +__version__ = "0.1.1" +__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "logic"] diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index e69de29b..85574d60 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -0,0 +1,3 @@ +from .simple_equations import SimpleEquationsDataset, SimpleEquationsConfig, simple_equations_dataset + +__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"] diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index b452dae4..e6cc2133 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -35,6 +35,11 @@ class SimpleEquationsDataset(ProceduralDataset): def __init__(self, config: SimpleEquationsConfig): self.config = config self.config.validate() + self._prompt_templates = [ + "Find the value of {variable} in the equation: {equation}", + "Solve for {variable}: {equation}", + "Determine the value of {variable} that satisfies: {equation}", + ] super().__init__(seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -53,7 +58,7 @@ class SimpleEquationsDataset(ProceduralDataset): equation, solution = self._generate_equation(rng, variable) return { - "question": equation, + "question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation), "answer": str(solution), "metadata": { "equation": equation, @@ -77,34 +82,45 @@ class SimpleEquationsDataset(ProceduralDataset): """ x = Symbol(variable) - # Generate left side + # Generate terms for left side num_terms = rng.randint(self.config.min_terms, self.config.max_terms) terms = [] - # First term includes the variable - coef = rng.randint(self.config.min_value, self.config.max_value) - terms.append(coef * x) - - # Add remaining terms - for _ in range(num_terms - 1): + # Generate all constant terms first + for _ in range(num_terms): value = rng.randint(self.config.min_value, self.config.max_value) - op = rng.choice(self.config.operators) - - if op == '+': - terms.append(value) - elif op == '-': - terms.append(-value) - else: # '*' - terms[-1] = terms[-1] * value + terms.append(value) - left_side = sum(terms) + # Replace one random term with the variable term + var_pos = rng.randint(0, num_terms - 1) + coef = rng.randint(self.config.min_value, self.config.max_value) + terms[var_pos] = coef * x + + # Apply operators between terms + expr = terms[0] + for i in range(1, num_terms): + op = rng.choice(self.config.operators) + if op == '+': + expr = expr + terms[i] + elif op == '-': + expr = expr - terms[i] + else: # '*' + expr = expr * terms[i] + + left_side = expr # Generate right side right_side = rng.randint(self.config.min_value, self.config.max_value) # Create equation equation = Eq(left_side, right_side) - solution = solve(equation, x)[0] + solutions = solve(equation, x) + + # Check if we found any solutions + if not solutions: + return self._generate_equation(rng, variable) + + solution = solutions[0] # Only return if solution is a positive integer if not (isinstance(solution, sympy.Integer) and solution > 0): @@ -115,7 +131,7 @@ class SimpleEquationsDataset(ProceduralDataset): def simple_equations_dataset( min_terms: int = 2, - max_terms: int = 4, + max_terms: int = 5, min_value: int = 1, max_value: int = 20, operators: tuple = ('+', '-', '*'), diff --git a/reasoning_gym/algebra/test.py b/reasoning_gym/algebra/test.py index 884bba6b..f038c37d 100644 --- a/reasoning_gym/algebra/test.py +++ b/reasoning_gym/algebra/test.py @@ -16,8 +16,22 @@ def test_simple_equations_generation(): # Validate equation format equation = item["metadata"]["equation"] + variable = item["metadata"]["variable"] assert "=" in equation - assert item["metadata"]["variable"] in equation + assert variable in equation + + # Validate question format + question = item["question"] + assert variable in question + assert equation in question + assert any( + prompt in question + for prompt in [ + "Find the value of", + "Solve for", + "Determine the value of" + ] + ) def test_simple_equations_config():