mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
Merge pull request #1 from panispani/polynomial
Add polynomial equations (extension of simple equations)
This commit is contained in:
commit
c6a4931eae
4 changed files with 338 additions and 1 deletions
31
README.md
31
README.md
|
|
@ -8,6 +8,7 @@ The goal is to generate virtually infinite data with adjustable complexity.
|
|||
|
||||
#### Algebra Tasks
|
||||
- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3*x + 2 = 14")
|
||||
- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h**4 + 4*h**2 - 5*h = 0")
|
||||
|
||||
#### Arithmetic Tasks
|
||||
- `BasicArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, *, /)
|
||||
|
|
@ -24,6 +25,7 @@ The goal is to generate virtually infinite data with adjustable complexity.
|
|||
- `NumberFilteringDataset`: Filter numbers based on comparison with threshold
|
||||
- `NumberSortingDataset`: Sort lists of numbers in ascending or descending order
|
||||
- `WordReversalDataset`: Reverse word order in text spans
|
||||
- `Sorting`
|
||||
|
||||
#### Cognition Tasks
|
||||
- `NumberSequenceDataset`: Generate number sequences with discoverable patterns
|
||||
|
|
@ -41,6 +43,35 @@ The goal is to generate virtually infinite data with adjustable complexity.
|
|||
|
||||
### Available Generators
|
||||
|
||||
### PolynomialEquations
|
||||
|
||||
Generate polynomial equation with configurable complexity:
|
||||
```python
|
||||
from reasoning_gym.algebra import PolynomialEquationsConfig, PolynomialEquationsConfig
|
||||
|
||||
config = PolynomialEquationsConfig(
|
||||
min_terms=3,
|
||||
max_terms=4,
|
||||
min_degree=4,
|
||||
max_degree=4,
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
size=3,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
dataset = PolynomialEquationsDataset(config)
|
||||
for item in dataset:
|
||||
print(item)
|
||||
```
|
||||
|
||||
Example output:
|
||||
```
|
||||
{'question': 'Find the real value(s) of b in the equation: b**4 - b**3 - 5*b**2 = 0', 'answer': '[-1.79128784747792, 0.0, 2.79128784747792]', 'metadata': {'polynomial_expr': 'b**4 - b**3 - 5*b**2', 'variable': 'b', 'degree': 4, 'real_solutions': [-1.79128784747792, 0.0, 2.79128784747792]}}
|
||||
{'question': 'Solve the polynomial equation for real i:\n3*i**4 + 4*i**3 - 1 = 0', 'answer': '[]', 'metadata': {'polynomial_expr': '3*i**4 + 4*i**3 - 1', 'variable': 'i', 'degree': 4, 'real_solutions': []}}
|
||||
{'question': 'Solve the polynomial equation for real h:\n7*h**4 - 2*h**2 + h = 0', 'answer': '[-0.6998793469266564, 0.0]', 'metadata': {'polynomial_expr': '7*h**4 - 2*h**2 + h', 'variable': 'h', 'degree': 4, 'real_solutions': [-0.6998793469266564, 0.0]}}
|
||||
```
|
||||
|
||||
#### Basic Arithmetic
|
||||
Generates arithmetic problems with configurable complexity:
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset
|
||||
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset, polynomial_equations_dataset
|
||||
|
||||
__all__ = ["SimpleEquationsDataset", "SimpleEquationsConfig", "simple_equations_dataset"]
|
||||
__all__ = [
|
||||
"SimpleEquationsDataset",
|
||||
"SimpleEquationsConfig",
|
||||
"simple_equations_dataset",
|
||||
"PolynomialEquationsConfig",
|
||||
"PolynomialEquationsDataset",
|
||||
"polynomial_equations_dataset",
|
||||
]
|
||||
|
|
|
|||
180
reasoning_gym/algebra/polynomial_equations.py
Normal file
180
reasoning_gym/algebra/polynomial_equations.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import random
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import sympy
|
||||
from sympy import Symbol, Eq, solve, expand
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolynomialEquationsConfig:
|
||||
"""
|
||||
Configuration for polynomial equation task generation.
|
||||
"""
|
||||
|
||||
min_terms: int = 2 # Minimum number of polynomial terms
|
||||
max_terms: int = 4 # Maximum number of polynomial terms
|
||||
min_value: int = 1 # Minimum value for coefficients
|
||||
max_value: int = 100 # Maximum value for coefficients
|
||||
min_degree: int = 1 # Minimum polynomial degree
|
||||
max_degree: int = 3 # Maximum polynomial degree
|
||||
operators: Tuple[str, ...] = (
|
||||
"+",
|
||||
"-",
|
||||
) # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters."""
|
||||
assert self.min_terms > 0, "min_terms must be positive."
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."
|
||||
|
||||
assert self.min_value > 0, "min_value must be positive."
|
||||
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
||||
|
||||
assert self.min_degree >= 1, "min_degree must be >= 1."
|
||||
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
||||
|
||||
allowed_ops = {"+", "-"}
|
||||
assert len(self.operators) > 0, "operators tuple cannot be empty."
|
||||
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
||||
|
||||
|
||||
class PolynomialEquationsDataset(ProceduralDataset):
|
||||
"""
|
||||
Generates random polynomial equations of degree in [min_degree, max_degree].
|
||||
- The polynomial is formed by summing random terms of the form: coeff * x^exponent.
|
||||
- Then we solve "polynomial_expr = 0" using Sympy.
|
||||
- The solution may be real or complex; we filter real solutions by default for simplicity.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PolynomialEquationsConfig):
|
||||
config.validate()
|
||||
self.config = config
|
||||
self._prompt_templates = [
|
||||
"Find the real value(s) of {variable} in the equation: {polynomial_expanded} = 0",
|
||||
"Solve for real {variable}: {polynomial_expanded} = 0",
|
||||
"Determine the real value(s) of {variable} tha satisfies: {polynomial_expanded} = 0",
|
||||
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
|
||||
]
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""
|
||||
Generate a single polynomial equation item.
|
||||
|
||||
Returns:
|
||||
A dict with:
|
||||
- question: str (e.g. "Solve the polynomial equation: 2*x^2 - 3*x + 1 = 0")
|
||||
- answer: str (the sorted list of real solutions, e.g. "[0.5, 1.0]")
|
||||
- metadata: dict with details (polynomial_expr, degree, etc.)
|
||||
"""
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
# Get variable and generate polynomial equation in standard form
|
||||
variable = self._get_variable(rng)
|
||||
degree = rng.randint(self.config.min_degree, self.config.max_degree)
|
||||
polynomial_expr = self._generate_polynomial_expr(rng, variable, degree)
|
||||
polynomial_expanded = expand(polynomial_expr)
|
||||
|
||||
# Solve the polynomial = 0
|
||||
# We filter real solutions only
|
||||
solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False)
|
||||
real_solutions = []
|
||||
for sol in solutions:
|
||||
if sol.is_real:
|
||||
# Evaluate symbolic solution to a floating approximation
|
||||
real_solutions.append(float(sol.evalf()))
|
||||
real_solutions.sort()
|
||||
answer_str = str(real_solutions)
|
||||
|
||||
return {
|
||||
"question": rng.choice(self._prompt_templates).format(
|
||||
variable=variable,
|
||||
polynomial_expanded=polynomial_expanded,
|
||||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"polynomial_expr": str(polynomial_expanded),
|
||||
"variable": variable,
|
||||
"degree": degree,
|
||||
"real_solutions": real_solutions,
|
||||
},
|
||||
}
|
||||
|
||||
def _get_variable(self, rng: random.Random) -> str:
|
||||
"""Get a random lowercase variable name"""
|
||||
return rng.choice(string.ascii_lowercase)
|
||||
|
||||
def _generate_polynomial_expr(self, rng: random.Random, variable: Symbol, degree: int):
|
||||
"""
|
||||
Randomly generate a polynomial expression of 'degree'.
|
||||
We'll use the config parameters:
|
||||
- min_terms, max_terms: how many total terms to combine
|
||||
- min_value, max_value: range for coefficients
|
||||
- operators: to decide sign flips or direct addition
|
||||
|
||||
Args:
|
||||
rng: Random number generator
|
||||
variable: Variable symbol to use in equation
|
||||
degree: Highest degree. We ensure that there is at least one term with exponent=degree
|
||||
|
||||
Returns:
|
||||
Polynomial string
|
||||
"""
|
||||
x = Symbol(variable)
|
||||
|
||||
# Choose the number of terms and their respective degrees
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
|
||||
chosen_exponents = [degree]
|
||||
# Fill the rest randomly in [0, degree]
|
||||
for _ in range(num_terms - 1):
|
||||
exp = rng.randint(0, degree)
|
||||
chosen_exponents.append(exp)
|
||||
|
||||
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
|
||||
polynomial_expr = 0
|
||||
for exp in chosen_exponents:
|
||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
||||
# If '-' in operators, we can randomly flip the sign
|
||||
if "-" in self.config.operators and rng.random() < 0.5:
|
||||
coeff = -coeff
|
||||
term_expr = coeff * (x**exp)
|
||||
polynomial_expr += term_expr
|
||||
|
||||
return polynomial_expr
|
||||
|
||||
|
||||
def polynomial_equations_dataset(
|
||||
min_terms: int = 2,
|
||||
max_terms: int = 4,
|
||||
min_value: int = 1,
|
||||
max_value: int = 100,
|
||||
min_degree: int = 1,
|
||||
max_degree: int = 3,
|
||||
operators: Tuple[str, ...] = ("+", "-"),
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> PolynomialEquationsDataset:
|
||||
"""
|
||||
Factory function for creating a PolynomialEquationsDataset.
|
||||
Example usage:
|
||||
dataset = polynomial_equations_dataset(min_degree=2, max_degree=3, ...)
|
||||
"""
|
||||
config = PolynomialEquationsConfig(
|
||||
min_terms=min_terms,
|
||||
max_terms=max_terms,
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
min_degree=min_degree,
|
||||
max_degree=max_degree,
|
||||
operators=operators,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return PolynomialEquationsDataset(config)
|
||||
118
tests/test_polynomial_equations.py
Normal file
118
tests/test_polynomial_equations.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
from sympy import sympify, Symbol
|
||||
|
||||
from reasoning_gym.algebra.polynomial_equations import (
|
||||
PolynomialEquationsConfig,
|
||||
PolynomialEquationsDataset,
|
||||
polynomial_equations_dataset,
|
||||
)
|
||||
|
||||
|
||||
def test_polynomial_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialEquationsConfig(min_terms=0).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialEquationsConfig(min_value=0).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialEquationsConfig(min_degree=0, max_degree=3).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialEquationsConfig(min_degree=4, max_degree=3).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialEquationsConfig(operators=("^",)).validate()
|
||||
|
||||
|
||||
def test_polynomial_equations_dataset_basic():
|
||||
"""Test dataset creation and length"""
|
||||
dataset_size = 50
|
||||
config = PolynomialEquationsConfig(
|
||||
min_terms=2,
|
||||
max_terms=3,
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
min_degree=1,
|
||||
max_degree=2,
|
||||
seed=42,
|
||||
size=dataset_size,
|
||||
)
|
||||
|
||||
dataset = PolynomialEquationsDataset(config)
|
||||
|
||||
assert len(dataset) == dataset_size
|
||||
|
||||
|
||||
def test_polynomial_equations_dataset_items():
|
||||
"""Test that generated items have correct structure"""
|
||||
ds = polynomial_equations_dataset(
|
||||
min_terms=2,
|
||||
max_terms=3,
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
min_degree=1,
|
||||
max_degree=2,
|
||||
size=3,
|
||||
seed=100,
|
||||
)
|
||||
|
||||
for item in ds:
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||
assert isinstance(item["metadata"]["variable"], str)
|
||||
assert isinstance(item["metadata"]["degree"], int)
|
||||
assert isinstance(item["metadata"]["real_solutions"], list)
|
||||
|
||||
# Check polynomial_expr existence
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
# Ensure it can parse with sympy
|
||||
sympify(poly_str)
|
||||
|
||||
|
||||
def test_polynomial_equations_dataset_deterministic():
|
||||
"""Test dataset reproducibility with fixed seed."""
|
||||
cfg = PolynomialEquationsConfig(seed=999, size=3)
|
||||
ds1 = PolynomialEquationsDataset(cfg)
|
||||
ds2 = PolynomialEquationsDataset(cfg)
|
||||
|
||||
for i in range(len(ds1)):
|
||||
assert ds1[i] == ds2[i], "Polynomial datasets with same seed should match exactly."
|
||||
|
||||
|
||||
def test_polynomial_solutions_evaluation():
|
||||
"""Test that real_solutions satisfy the polynomial equation."""
|
||||
ds = polynomial_equations_dataset(
|
||||
min_terms=2,
|
||||
max_terms=4,
|
||||
min_value=1,
|
||||
max_value=10,
|
||||
min_degree=1,
|
||||
max_degree=3,
|
||||
size=5,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
for item in ds:
|
||||
# Extract the polynomial expression and solutions
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
real_solutions = item["metadata"]["real_solutions"]
|
||||
x = Symbol(item["metadata"]["variable"])
|
||||
# Parse the polynomial expression
|
||||
poly_expr = sympify(poly_str)
|
||||
|
||||
# Verify that each solution satisfies the polynomial
|
||||
for solution in real_solutions:
|
||||
# Evaluate the expression with the solution substituted
|
||||
evaluated_value = poly_expr.subs(x, solution)
|
||||
|
||||
# Ensure the evaluated value is close to zero (numerical stability threshold)
|
||||
assert abs(evaluated_value) < 1e-6, (
|
||||
f"Solution {solution} does not satisfy the polynomial {poly_str}. "
|
||||
f"Evaluated value: {evaluated_value}"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue