diff --git a/README.md b/README.md
index df242838..0ff42c93 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ The goal is to generate virtually infinite data with adjustable complexity.
 
 #### Algebra Tasks
 - `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3*x + 2 = 14")
+- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h**4 + 4*h**2 - 5*h = 0")
 
 #### Arithmetic Tasks
 - `BasicArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, *, /)
@@ -24,6 +25,7 @@ The goal is to generate virtually infinite data with adjustable complexity.
 - `NumberFilteringDataset`: Filter numbers based on comparison with threshold
 - `NumberSortingDataset`: Sort lists of numbers in ascending or descending order
 - `WordReversalDataset`: Reverse word order in text spans
+- `Sorting`
 
 #### Cognition Tasks
 - `NumberSequenceDataset`: Generate number sequences with discoverable patterns
@@ -41,6 +43,35 @@ The goal is to generate virtually infinite data with adjustable complexity.
 
 ### Available Generators
 
+#### Polynomial Equations
+
+Generates polynomial equations with configurable complexity:
+```python
+from reasoning_gym.algebra import PolynomialEquationsConfig, PolynomialEquationsDataset
+
+config = PolynomialEquationsConfig(
+    min_terms=3,
+    max_terms=4,
+    min_degree=4,
+    max_degree=4,
+    min_value=1,
+    max_value=5,
+    size=3,
+    seed=123,
+)
+
+dataset = PolynomialEquationsDataset(config)
+for item in dataset:
+    print(item)
+```
+
+Example output:
+```
+{'question': 'Find the real value(s) of b in the equation: b**4 - b**3 - 5*b**2 = 0', 'answer': '[-1.79128784747792, 0.0, 2.79128784747792]', 'metadata': {'polynomial_expr': 'b**4 - b**3 - 5*b**2', 'variable': 'b', 'degree': 4, 'real_solutions': [-1.79128784747792, 0.0, 2.79128784747792]}}
+{'question': 'Solve the polynomial equation for real i:\n3*i**4 + 4*i**3 - 1 = 0', 'answer': '[]', 'metadata': {'polynomial_expr': '3*i**4 + 4*i**3 - 1', 'variable': 'i', 'degree': 4, 'real_solutions': []}}
+{'question': 'Solve the polynomial equation for real h:\n7*h**4 - 2*h**2 + h = 0', 'answer': '[-0.6998793469266564, 0.0]', 'metadata': {'polynomial_expr': '7*h**4 - 2*h**2 + h', 'variable': 'h', 'degree': 4, 'real_solutions': [-0.6998793469266564, 0.0]}}
+```
+
 #### Basic Arithmetic
 Generates arithmetic problems with configurable complexity:
 ```python
diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py
index 251f6583..28cfe7a2 100644
--- a/reasoning_gym/algebra/__init__.py
+++ b/reasoning_gym/algebra/__init__.py
@@ -1,3 +1,11 @@
 from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset
+from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset, polynomial_equations_dataset
 
-__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"]
+__all__ = [
+    "SimpleEquationsDataset",
+    "SimpleEquationsConfig",
+    "simple_equations_dataset",
+    "PolynomialEquationsConfig",
+    "PolynomialEquationsDataset",
+    "polynomial_equations_dataset",
+]
diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py
new file mode 100644
index 00000000..ac02c19c
--- /dev/null
+++ b/reasoning_gym/algebra/polynomial_equations.py
@@ -0,0 +1,191 @@
+import random
+import string
+from dataclasses import dataclass
+from typing import Optional, Tuple, List
+
+import sympy
+from sympy import Symbol, Eq, solve, expand
+
+from ..dataset import ProceduralDataset
+
+
+@dataclass
+class PolynomialEquationsConfig:
+    """
+    Configuration for polynomial equation task generation.
+    """
+
+    min_terms: int = 2  # Minimum number of polynomial terms
+    max_terms: int = 4  # Maximum number of polynomial terms
+    min_value: int = 1  # Minimum absolute value for coefficients (the sign is chosen separately)
+    max_value: int = 100  # Maximum absolute value for coefficients
+    min_degree: int = 1  # Minimum polynomial degree
+    max_degree: int = 3  # Maximum polynomial degree
+    operators: Tuple[str, ...] = (
+        "+",
+        "-",
+    )  # Allowed operators between terms; avoid '*' or '/' because they would change the degree
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self):
+        """Validate configuration parameters; raises AssertionError on the first violated constraint."""
+        assert self.min_terms > 0, "min_terms must be positive."
+        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."
+
+        assert self.min_value > 0, "min_value must be positive."
+        assert self.max_value >= self.min_value, "max_value must be >= min_value."
+
+        assert self.min_degree >= 1, "min_degree must be >= 1."
+        assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
+
+        allowed_ops = {"+", "-"}
+        assert len(self.operators) > 0, "operators tuple cannot be empty."
+        assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
+
+
+class PolynomialEquationsDataset(ProceduralDataset):
+    """
+    Generates random polynomial equations of degree in [min_degree, max_degree].
+    - The polynomial is formed by summing random terms of the form: coeff * x^exponent.
+    - Then we solve "polynomial_expr = 0" using Sympy.
+    - The solutions may be real or complex; only the real ones are kept in the answer.
+    """
+
+    def __init__(self, config: PolynomialEquationsConfig):
+        config.validate()
+        self.config = config
+        self._prompt_templates = [
+            "Find the real value(s) of {variable} in the equation: {polynomial_expanded} = 0",
+            "Solve for real {variable}: {polynomial_expanded} = 0",
+            "Determine the real value(s) of {variable} that satisfy: {polynomial_expanded} = 0",
+            "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
+        ]
+        super().__init__(seed=config.seed, size=config.size)
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Generate a single polynomial equation item.
+
+        Returns:
+            A dict with:
+                - question: str (e.g. "Solve for real x: 2*x**2 - 3*x + 1 = 0")
+                - answer: str (the sorted list of real solutions, e.g. "[0.5, 1.0]")
+                - metadata: dict with details (polynomial_expr, degree, etc.)
+        """
+        rng = random.Random(self.seed + idx)
+
+        # Get variable and generate polynomial equation in standard form
+        variable = self._get_variable(rng)
+        degree = rng.randint(self.config.min_degree, self.config.max_degree)
+        polynomial_expr = self._generate_polynomial_expr(rng, variable, degree)
+        polynomial_expanded = expand(polynomial_expr)
+
+        # Solve the polynomial = 0
+        # We filter real solutions only
+        solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False)
+        real_solutions = []
+        for sol in solutions:
+            if sol.is_real:
+                # Evaluate symbolic solution to a floating approximation
+                real_solutions.append(float(sol.evalf()))
+        real_solutions.sort()
+        answer_str = str(real_solutions)
+
+        return {
+            "question": rng.choice(self._prompt_templates).format(
+                variable=variable,
+                polynomial_expanded=polynomial_expanded,
+            ),
+            "answer": answer_str,
+            "metadata": {
+                "polynomial_expr": str(polynomial_expanded),
+                "variable": variable,
+                "degree": degree,
+                "real_solutions": real_solutions,
+            },
+        }
+
+    def _get_variable(self, rng: random.Random) -> str:
+        """Get a random lowercase variable name"""
+        return rng.choice(string.ascii_lowercase)
+
+    def _generate_polynomial_expr(self, rng: random.Random, variable: str, degree: int):
+        """
+        Randomly generate a polynomial expression of 'degree'.
+        We'll use the config parameters:
+            - min_terms, max_terms: how many total terms to combine
+            - min_value, max_value: range for coefficients
+            - operators: to decide sign flips or direct addition
+
+        Args:
+            rng: Random number generator
+            variable: Name of the variable to use in the equation
+            degree: Highest degree. We ensure that the term with exponent=degree never cancels out
+
+        Returns:
+            Sympy expression for the polynomial (the left-hand side of "polynomial = 0")
+        """
+        x = Symbol(variable)
+
+        # Choose the number of terms and their respective degrees
+        num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
+        # Keep track of exponents, exponents can repeat or skip but we force the highest exponent
+        chosen_exponents = [degree]
+        # Fill the rest randomly in [0, degree]
+        for _ in range(num_terms - 1):
+            exp = rng.randint(0, degree)
+            chosen_exponents.append(exp)
+
+        # Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
+        polynomial_expr = 0
+        leading_coeff = 0  # Net coefficient of x**degree once like terms are combined
+        forced_coeff = None  # Coefficient drawn for the forced highest-degree term
+        for exp in chosen_exponents:
+            coeff = rng.randint(self.config.min_value, self.config.max_value)
+            # If '-' in operators, we can randomly flip the sign
+            if "-" in self.config.operators and rng.random() < 0.5:
+                coeff = -coeff
+            if forced_coeff is None:
+                forced_coeff = coeff
+            if exp == degree:
+                leading_coeff += coeff
+            term_expr = coeff * (x**exp)
+            polynomial_expr += term_expr
+
+        # Repeated top-degree terms with opposite signs can cancel, which would lower the real
+        # degree (or leave the zero polynomial); restore the forced term without extra rng draws
+        if leading_coeff == 0:
+            polynomial_expr += forced_coeff * (x**degree)
+
+        return polynomial_expr
+
+
+def polynomial_equations_dataset(
+    min_terms: int = 2,
+    max_terms: int = 4,
+    min_value: int = 1,
+    max_value: int = 100,
+    min_degree: int = 1,
+    max_degree: int = 3,
+    operators: Tuple[str, ...] = ("+", "-"),
+    seed: Optional[int] = None,
+    size: int = 500,
+) -> PolynomialEquationsDataset:
+    """
+    Factory function for creating a PolynomialEquationsDataset.
+    Example usage:
+        dataset = polynomial_equations_dataset(min_degree=2, max_degree=3, ...)
+    """
+    config = PolynomialEquationsConfig(
+        min_terms=min_terms,
+        max_terms=max_terms,
+        min_value=min_value,
+        max_value=max_value,
+        min_degree=min_degree,
+        max_degree=max_degree,
+        operators=operators,
+        seed=seed,
+        size=size,
+    )
+    return PolynomialEquationsDataset(config)
diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py
new file mode 100644
index 00000000..daa33e2e
--- /dev/null
+++ b/tests/test_polynomial_equations.py
@@ -0,0 +1,118 @@
+import pytest
+from sympy import sympify, Symbol
+
+from reasoning_gym.algebra.polynomial_equations import (
+    PolynomialEquationsConfig,
+    PolynomialEquationsDataset,
+    polynomial_equations_dataset,
+)
+
+
+def test_polynomial_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        PolynomialEquationsConfig(min_terms=0).validate()
+
+    with pytest.raises(AssertionError):
+        PolynomialEquationsConfig(min_value=0).validate()
+
+    with pytest.raises(AssertionError):
+        PolynomialEquationsConfig(min_degree=0, max_degree=3).validate()
+
+    with pytest.raises(AssertionError):
+        PolynomialEquationsConfig(min_degree=4, max_degree=3).validate()
+
+    with pytest.raises(AssertionError):
+        PolynomialEquationsConfig(operators=("^",)).validate()
+
+
+def test_polynomial_equations_dataset_basic():
+    """Test dataset creation and length"""
+    dataset_size = 50
+    config = PolynomialEquationsConfig(
+        min_terms=2,
+        max_terms=3,
+        min_value=1,
+        max_value=5,
+        min_degree=1,
+        max_degree=2,
+        seed=42,
+        size=dataset_size,
+    )
+
+    dataset = PolynomialEquationsDataset(config)
+
+    assert len(dataset) == dataset_size
+
+
+def test_polynomial_equations_dataset_items():
+    """Test that generated items have correct structure"""
+    ds = polynomial_equations_dataset(
+        min_terms=2,
+        max_terms=3,
+        min_value=1,
+        max_value=5,
+        min_degree=1,
+        max_degree=2,
+        size=3,
+        seed=100,
+    )
+
+    for item in ds:
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert isinstance(item["metadata"]["polynomial_expr"], str)
+        assert isinstance(item["metadata"]["variable"], str)
+        assert isinstance(item["metadata"]["degree"], int)
+        assert isinstance(item["metadata"]["real_solutions"], list)
+
+        # Check polynomial_expr existence
+        poly_str = item["metadata"]["polynomial_expr"]
+        # Ensure it can parse with sympy
+        sympify(poly_str)
+
+
+def test_polynomial_equations_dataset_deterministic():
+    """Test dataset reproducibility with fixed seed."""
+    cfg = PolynomialEquationsConfig(seed=999, size=3)
+    ds1 = PolynomialEquationsDataset(cfg)
+    ds2 = PolynomialEquationsDataset(cfg)
+
+    for i in range(len(ds1)):
+        assert ds1[i] == ds2[i], "Polynomial datasets with same seed should match exactly."
+
+
+def test_polynomial_solutions_evaluation():
+    """Test that real_solutions satisfy the polynomial equation."""
+    ds = polynomial_equations_dataset(
+        min_terms=2,
+        max_terms=4,
+        min_value=1,
+        max_value=10,
+        min_degree=1,
+        max_degree=3,
+        size=5,
+        seed=42,
+    )
+
+    for item in ds:
+        # Extract the polynomial expression and solutions
+        poly_str = item["metadata"]["polynomial_expr"]
+        real_solutions = item["metadata"]["real_solutions"]
+        x = Symbol(item["metadata"]["variable"])
+        # Parse the polynomial expression
+        poly_expr = sympify(poly_str)
+
+        # Verify that each solution satisfies the polynomial
+        for solution in real_solutions:
+            # Evaluate the expression with the solution substituted
+            evaluated_value = poly_expr.subs(x, solution)
+
+            # Ensure the evaluated value is close to zero (numerical stability threshold)
+            assert abs(evaluated_value) < 1e-6, (
+                f"Solution {solution} does not satisfy the polynomial {poly_str}. "
+                f"Evaluated value: {evaluated_value}"
+            )