mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Add PolynomialMultiplicationDataset (#64)
* Add PolynomialMultiplicationDataset
This commit is contained in:
parent
426fa22fcc
commit
847442ef0a
5 changed files with 332 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -21,6 +21,7 @@ wheels/
|
|||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
.python-version
|
||||
|
||||
# Virtual Environment
|
||||
venv/
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
|
|||
|
||||
- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3\*x + 2 = 14")
|
||||
- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h\*\*4 + 4*h\**2 - 5*h = 0")
|
||||
- `PolynomialMultiplicationDataset`: Generate polynomial multiplications (e.g. "(8x^3 + x + 2)*(y - 3)")
|
||||
|
||||
### <small>Arithmetic Tasks</small>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset
|
||||
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
|
||||
from .polynomial_multiplication import PolynomialMultiplicationConfig, PolynomialMultiplicationDataset
|
||||
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
|
||||
from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset
|
||||
|
||||
|
|
@ -12,4 +13,6 @@ __all__ = [
|
|||
"SimpleEquationsConfig",
|
||||
"SimpleIntegrationConfig",
|
||||
"SimpleIntegrationDataset",
|
||||
"PolynomialMultiplicationConfig",
|
||||
"PolynomialMultiplicationDataset",
|
||||
]
|
||||
|
|
|
|||
161
reasoning_gym/algebra/polynomial_multiplication.py
Normal file
161
reasoning_gym/algebra/polynomial_multiplication.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
import random
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import sympy as sp
|
||||
from sympy import Eq, Symbol, expand, solve
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class PolynomialMultiplicationConfig:
    """
    Configuration for polynomial multiplication task generation.

    Call :meth:`validate` to check that the ranges are consistent; it raises
    ``AssertionError`` on the first violated constraint.
    """

    min_terms: int = 2  # Minimum number of polynomial terms
    max_terms: int = 4  # Maximum number of polynomial terms
    min_value: int = 1  # Minimum value for coefficients
    max_value: int = 100  # Maximum value for coefficients
    min_degree: int = 1  # Minimum polynomial degree
    max_degree: int = 3  # Maximum polynomial degree
    min_polynomials: int = 2  # Minimum number of polynomials being multiplied
    max_polynomials: int = 3  # Maximum number of polynomials being multiplied
    # Must be a plain bool: the previous default `(True,)` was a one-element
    # tuple (stray trailing comma), which is truthy but is not a bool.
    single_variable: bool = True
    operators: Tuple[str, ...] = (
        "+",
        "-",
    )  # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if any range is empty/inverted, a minimum is below
                its allowed lower bound, or ``operators`` is empty or contains
                anything other than "+" and "-".
        """
        assert self.min_terms > 0, "min_terms must be positive."
        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."

        assert self.min_value > 0, "min_value must be positive."
        assert self.max_value >= self.min_value, "max_value must be >= min_value."

        assert self.min_degree >= 1, "min_degree must be >= 1."
        assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."

        assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
        assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."

        allowed_ops = {"+", "-"}
        assert len(self.operators) > 0, "operators tuple cannot be empty."
        assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
||||
|
||||
|
||||
class PolynomialMultiplicationDataset(ProceduralDataset):
    """
    Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree].
    - Each polynomial is formed by summing random terms of the form: coeff * x^exponent.
    - Then we find "F = P_0 * ... * P_n" using Sympy.
    """

    def __init__(self, config: PolynomialMultiplicationConfig):
        # One of these templates is picked per item in __getitem__.
        self._prompt_templates = [
            "Simplify this expression: {polynomial_expr}",
            "Calculate the following: {polynomial_expr}",
        ]
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single polynomial multiplication item.

        All randomness comes from a generator seeded with ``self.seed + idx``.

        Args:
            idx: Index of the item to generate

        Returns:
            A dict with:
                - question: str, a prompt template filled with the unexpanded product
                - answer: the expanded product as a Sympy expression (not a str;
                  its string form is stored in metadata["result"])
                - metadata: dict with polynomial_expr (str), single_variable
                  (copied from the config) and result (str)
        """
        rng = random.Random(self.seed + idx)
        number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
        polynomials = [self._generate_polynomial_expr(rng) for _ in range(number_polynomials)]

        polynomial_expr = sp.prod(polynomials)
        product = sp.expand(polynomial_expr)

        return {
            "question": rng.choice(self._prompt_templates).format(
                polynomial_expr=polynomial_expr,
            ),
            "answer": product,
            "metadata": {
                "polynomial_expr": str(polynomial_expr),
                "single_variable": self.config.single_variable,
                "result": str(product),
            },
        }

    def _get_variable(self, rng: random.Random) -> str:
        """Return "x" in single-variable mode, otherwise a random lowercase ASCII letter."""
        if self.config.single_variable:
            return "x"
        return rng.choice(string.ascii_lowercase)

    def _generate_polynomial_expr(self, rng: random.Random) -> sp.Expr:
        """
        Randomly generate one polynomial as a Sympy expression.

        We'll use the config parameters:
            - min_degree, max_degree: range for the exponent of the forced leading term
            - min_terms, max_terms: how many total terms to combine
            - min_value, max_value: range for coefficient magnitudes
            - operators: if '-' is present, each coefficient is negated with probability 0.5

        Args:
            rng: Random number generator

        Returns:
            Sympy expression equal to the sum of the generated terms. Terms that
            share an exponent are combined by Sympy, so the result may have fewer
            than the drawn number of terms.
        """
        variable = self._get_variable(rng)
        degree = rng.randint(self.config.min_degree, self.config.max_degree)

        x = Symbol(variable)

        # Choose the number of terms and their respective degrees
        num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
        # Exponents can repeat or skip, but the first term always uses the highest exponent;
        # the rest are drawn uniformly from [0, degree].
        chosen_exponents = [degree] + [rng.randint(0, degree) for _ in range(num_terms - 1)]

        # Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
        polynomial_expr = 0
        for exp in chosen_exponents:
            coeff = rng.randint(self.config.min_value, self.config.max_value)
            # If '-' in operators, we can randomly flip the sign.
            # (rng.random() is only consumed when '-' is allowed.)
            if "-" in self.config.operators and rng.random() < 0.5:
                coeff = -coeff
            polynomial_expr += coeff * (x**exp)

        return polynomial_expr

    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
        """
        Score a candidate answer against the expected expanded product.

        Args:
            answer: Candidate expression as a string, or None
            metadata: Item metadata; metadata["result"] holds the expected expression string

        Returns:
            0.0 if answer is None, 1.0 if the answer is symbolically equal to the
            expected result, 0.05 if it parses but is not equal, and 0.01 if
            parsing or comparison raises.
        """
        reward = 0.0
        if answer is not None:
            try:
                # NOTE(security): sympy's parse_expr evaluates the string with eval();
                # `answer` is external input, so this should only run on trusted or
                # sandboxed input. Left as-is to preserve scoring behavior.
                predicted_poly = sp.parse_expr(answer)
                target_poly = sp.parse_expr(metadata["result"])

                # Check if the difference simplifies to zero (i.e. they are equivalent).
                if sp.simplify(predicted_poly - target_poly) == 0:
                    reward = 1.0
                elif answer.strip():
                    reward = 0.05
                else:
                    reward = 0.01
            except Exception:
                # Unparseable or otherwise unusable answers get the minimum non-zero reward.
                reward = 0.01
        return reward
|
||||
|
||||
|
||||
# Register the dataset class together with its config class under the name "polynomial_multiplication".
register_dataset("polynomial_multiplication", PolynomialMultiplicationDataset, PolynomialMultiplicationConfig)
|
||||
166
tests/test_polynomial_multiplication.py
Normal file
166
tests/test_polynomial_multiplication.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import pytest
|
||||
import sympy as sp
|
||||
|
||||
from reasoning_gym import create_dataset
|
||||
from reasoning_gym.algebra.polynomial_multiplication import (
|
||||
PolynomialMultiplicationConfig,
|
||||
PolynomialMultiplicationDataset,
|
||||
)
|
||||
|
||||
|
||||
def test_polynomial_config_validation():
    """Each invalid configuration must fail validate() with an AssertionError."""
    invalid_overrides = [
        {"min_terms": 0},
        {"min_value": 0},
        {"min_degree": 0, "max_degree": 3},
        {"min_degree": 4, "max_degree": 3},
        {"operators": ("^",)},
        {"min_polynomials": 5, "max_polynomials": 2},
    ]
    for overrides in invalid_overrides:
        with pytest.raises(AssertionError):
            PolynomialMultiplicationConfig(**overrides).validate()
|
||||
|
||||
|
||||
def test_polynomial_multiplication_dataset_basic():
    """A dataset built directly from a config reports the configured size."""
    expected_len = 50
    dataset = PolynomialMultiplicationDataset(
        PolynomialMultiplicationConfig(
            min_terms=2,
            max_terms=3,
            min_value=1,
            max_value=5,
            min_degree=1,
            max_degree=2,
            min_polynomials=2,
            max_polynomials=3,
            single_variable=True,
            seed=42,
            size=expected_len,
        )
    )

    assert len(dataset) == expected_len
|
||||
|
||||
|
||||
def test_polynomial_equations_dataset_items():
    """Generated items expose the expected keys, metadata types and a parseable expression."""
    settings = dict(
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        size=3,
        seed=100,
    )
    ds = create_dataset("polynomial_multiplication", **settings)

    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Metadata fields must have the documented types.
        metadata = item["metadata"]
        assert isinstance(metadata["polynomial_expr"], str)
        assert isinstance(metadata["single_variable"], bool)

        # The stored expression string must be parseable by sympy (raises otherwise).
        sp.sympify(metadata["polynomial_expr"])
|
||||
|
||||
|
||||
def test_polynomial_equations_dataset_deterministic():
    """Two datasets built from the same seeded config yield identical items."""
    shared_config = PolynomialMultiplicationConfig(seed=999, size=3)
    first = PolynomialMultiplicationDataset(shared_config)
    second = PolynomialMultiplicationDataset(shared_config)

    for index in range(len(first)):
        assert first[index] == second[index], "Polynomial datasets with same seed should match exactly."
|
||||
|
||||
|
||||
def test_polynomial_solutions_evaluation():
    """Each item's answer equals the sympy expansion of its stored product expression."""
    settings = dict(
        min_terms=2,
        max_terms=4,
        min_value=1,
        max_value=10,
        min_degree=1,
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        size=5,
        seed=42,
    )
    ds = create_dataset("polynomial_multiplication", **settings)

    for item in ds:
        # Re-expand the unexpanded product recorded in the metadata ...
        expanded = sp.expand(item["metadata"]["polynomial_expr"])

        # ... and compare it with the answer produced by the dataset.
        assert expanded == item["answer"]
|
||||
|
||||
|
||||
def test_score_function():
    """score_answer returns the expected reward tiers for a single-variable item."""
    settings = dict(
        min_terms=2,
        max_terms=4,
        min_value=1,
        max_value=10,
        min_degree=1,
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=True,
        size=1,
        seed=42,
    )
    ds = create_dataset("polynomial_multiplication", **settings)

    # Missing answer, correct answer, unparseable answer, parseable-but-wrong answer.
    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1
    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
|
||||
|
||||
|
||||
def test_multivariate_score_function():
    """score_answer returns the expected reward tiers for a multi-variable item."""
    settings = dict(
        min_terms=2,
        max_terms=4,
        min_value=1,
        max_value=10,
        min_degree=1,
        max_degree=3,
        min_polynomials=2,
        max_polynomials=5,
        single_variable=False,
        size=1,
        seed=42,
    )
    ds = create_dataset("polynomial_multiplication", **settings)

    # Missing answer, correct answer, unparseable answer, parseable-but-wrong answer.
    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1
    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
|
||||
Loading…
Add table
Add a link
Reference in a new issue