Add PolynomialMultiplicationDataset (#64)

* Add PolynomialMultiplicationDataset
This commit is contained in:
tohskai 2025-02-07 14:06:41 +01:00 committed by GitHub
parent 426fa22fcc
commit 847442ef0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 332 additions and 0 deletions

1
.gitignore vendored
View file

@ -21,6 +21,7 @@ wheels/
*.egg-info/
.installed.cfg
*.egg
.python-version
# Virtual Environment
venv/

View file

@ -72,6 +72,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3\*x + 2 = 14")
- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h\*\*4 + 4*h\**2 - 5*h = 0")
- `PolynomialMultiplicationDataset`: Generate polynomial multiplications (e.g. "(8x^3 + x + 2)*(y - 3)")
### <small>Arithmetic Tasks</small>

View file

@ -1,5 +1,6 @@
from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
from .polynomial_multiplication import PolynomialMultiplicationConfig, PolynomialMultiplicationDataset
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset
@ -12,4 +13,6 @@ __all__ = [
"SimpleEquationsConfig",
"SimpleIntegrationConfig",
"SimpleIntegrationDataset",
"PolynomialMultiplicationConfig",
"PolynomialMultiplicationDataset",
]

View file

@ -0,0 +1,161 @@
import random
import string
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import sympy as sp
from sympy import Eq, Symbol, expand, solve
from ..factory import ProceduralDataset, register_dataset
@dataclass
class PolynomialMultiplicationConfig:
    """
    Configuration for polynomial multiplication task generation.

    All numeric ranges are inclusive. Call `validate()` to check that the
    ranges are well-formed before using the config.
    """

    min_terms: int = 2  # Minimum number of polynomial terms
    max_terms: int = 4  # Maximum number of polynomial terms
    min_value: int = 1  # Minimum value for coefficients
    max_value: int = 100  # Maximum value for coefficients
    min_degree: int = 1  # Minimum polynomial degree
    max_degree: int = 3  # Maximum polynomial degree
    min_polynomials: int = 2  # Minimum number of polynomials being multiplied
    max_polynomials: int = 3  # Maximum number of polynomials being multiplied
    # Must be a plain bool: the previous default `(True,)` was a one-element tuple
    # (stray trailing comma) that only worked because non-empty tuples are truthy.
    single_variable: bool = True
    operators: Tuple[str, ...] = (
        "+",
        "-",
    )  # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """
        Validate configuration parameters.

        Raises:
            AssertionError: if any range is empty or inverted, or if `operators`
                is empty or contains anything other than "+" and "-".
        """
        # Kept as `assert` statements on purpose: callers (and the test-suite)
        # expect AssertionError from an invalid configuration.
        assert self.min_terms > 0, "min_terms must be positive."
        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."

        assert self.min_value > 0, "min_value must be positive."
        assert self.max_value >= self.min_value, "max_value must be >= min_value."

        assert self.min_degree >= 1, "min_degree must be >= 1."
        assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."

        assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
        assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."

        allowed_ops = {"+", "-"}
        assert len(self.operators) > 0, "operators tuple cannot be empty."
        assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
class PolynomialMultiplicationDataset(ProceduralDataset):
    """
    Generates between config.min_polynomials and config.max_polynomials random polynomials,
    each of degree in [config.min_degree, config.max_degree], and asks for their product.

    - Each polynomial is formed by summing random terms of the form: coeff * variable^exponent.
    - The expected result is the expanded product "F = P_1 * ... * P_n", computed with Sympy.
    """

    def __init__(self, config: PolynomialMultiplicationConfig):
        # The templates are assigned before the base-class initialiser is invoked.
        self._prompt_templates = [
            "Simplify this expression: {polynomial_expr}",
            "Calculate the following: {polynomial_expr}",
        ]
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single polynomial multiplication item.

        The item is derived from a private `random.Random(self.seed + idx)`, so the
        same seed and index always yield the same item.

        Returns:
            A dict with:
                - question: str, a prompt template filled with the unexpanded product
                - answer: the expanded product as returned by `sp.expand`
                  (a Sympy object; it is not converted to str here)
                - metadata: dict with `polynomial_expr` (str of the unexpanded product),
                  `single_variable` (copied from the config) and `result` (str of the
                  expanded product)
        """
        item_rng = random.Random(self.seed + idx)

        factor_count = item_rng.randint(self.config.min_polynomials, self.config.max_polynomials)
        factors = [self._generate_polynomial_expr(item_rng) for _ in range(factor_count)]

        unexpanded = sp.prod(factors)
        expanded = sp.expand(unexpanded)

        # The template is drawn only after every factor has been generated.
        template = item_rng.choice(self._prompt_templates)
        metadata = {
            "polynomial_expr": str(unexpanded),
            "single_variable": self.config.single_variable,
            "result": str(expanded),
        }
        return {
            "question": template.format(polynomial_expr=unexpanded),
            "answer": expanded,
            "metadata": metadata,
        }

    def _get_variable(self, rng: random.Random) -> str:
        """Return "x" in single-variable mode, otherwise a random lowercase ASCII letter."""
        return "x" if self.config.single_variable else rng.choice(string.ascii_lowercase)

    def _generate_polynomial_expr(self, rng: random.Random):
        """
        Randomly generate one polynomial in a single variable.

        Config parameters used:
            - min_degree, max_degree: range for the highest exponent
            - min_terms, max_terms: how many terms are summed
            - min_value, max_value: range for coefficient magnitudes
            - operators: if "-" is present, each coefficient is negated with probability 0.5

        Args:
            rng: Random number generator

        Returns:
            The sum of the generated terms (a Sympy expression, not a string)
        """
        cfg = self.config

        symbol = Symbol(self._get_variable(rng))
        top_degree = rng.randint(cfg.min_degree, cfg.max_degree)
        term_count = rng.randint(cfg.min_terms, cfg.max_terms)

        # The first term always carries the chosen degree; the remaining exponents are
        # drawn from [0, top_degree] and may repeat or leave gaps.
        exponents = [top_degree] + [rng.randint(0, top_degree) for _ in range(term_count - 1)]

        may_negate = "-" in cfg.operators
        total = 0
        for power in exponents:
            magnitude = rng.randint(cfg.min_value, cfg.max_value)
            # Short-circuit keeps the RNG untouched when negation is not allowed.
            negate = may_negate and rng.random() < 0.5
            signed_coeff = -magnitude if negate else magnitude
            total += signed_coeff * (symbol**power)
        return total

    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
        """
        Score a candidate answer against the stored expanded product.

        Args:
            answer: Candidate expression as text, or None when no answer was given
            metadata: Item metadata; `metadata["result"]` holds the target expression string

        Returns:
            0.0 for a None answer, 1.0 when `answer - target` simplifies to zero,
            0.05 for a parseable but non-equivalent answer, and 0.01 when parsing or
            comparison raises (or the answer is blank).
        """
        if answer is None:
            return 0.0

        try:
            # NOTE(review): sp.parse_expr evaluates the text it is given; `answer` comes from
            # outside this class -- confirm it is trusted or sandboxed by the caller.
            predicted = sp.parse_expr(answer)
            target = sp.parse_expr(metadata["result"])

            # Equivalent expressions have a difference that simplifies to zero.
            if sp.simplify(predicted - target) == 0:
                return 1.0
            return 0.05 if answer.strip() else 0.01
        except Exception:
            return 0.01


register_dataset("polynomial_multiplication", PolynomialMultiplicationDataset, PolynomialMultiplicationConfig)

View file

@ -0,0 +1,166 @@
import pytest
import sympy as sp
from reasoning_gym import create_dataset
from reasoning_gym.algebra.polynomial_multiplication import (
PolynomialMultiplicationConfig,
PolynomialMultiplicationDataset,
)
def test_polynomial_config_validation():
    """Test that invalid configs raise AssertionError and that the default config is accepted."""
    # Positive control: the defaults must validate without raising.
    PolynomialMultiplicationConfig().validate()

    # One entry per validation rule, including the inverted-range and
    # empty-operators cases.
    invalid_overrides = [
        {"min_terms": 0},
        {"min_terms": 3, "max_terms": 2},
        {"min_value": 0},
        {"min_value": 5, "max_value": 4},
        {"min_degree": 0, "max_degree": 3},
        {"min_degree": 4, "max_degree": 3},
        {"min_polynomials": 1},
        {"min_polynomials": 5, "max_polynomials": 2},
        {"operators": ("^",)},
        {"operators": ()},
    ]
    for overrides in invalid_overrides:
        with pytest.raises(AssertionError):
            PolynomialMultiplicationConfig(**overrides).validate()
def test_polynomial_multiplication_dataset_basic():
    """Check that a dataset built directly from a config reports the configured size."""
    expected_len = 50

    dataset = PolynomialMultiplicationDataset(
        PolynomialMultiplicationConfig(
            min_terms=2,
            max_terms=3,
            min_value=1,
            max_value=5,
            min_degree=1,
            max_degree=2,
            min_polynomials=2,
            max_polynomials=3,
            single_variable=True,
            seed=42,
            size=expected_len,
        )
    )

    assert len(dataset) == expected_len
def test_polynomial_equations_dataset_items():
    """Check that every generated item has the expected keys and metadata types."""
    params = {
        "min_terms": 2,
        "max_terms": 3,
        "min_value": 1,
        "max_value": 5,
        "min_degree": 1,
        "max_degree": 2,
        "min_polynomials": 2,
        "max_polynomials": 5,
        "single_variable": False,
        "size": 3,
        "seed": 100,
    }
    ds = create_dataset("polynomial_multiplication", **params)

    for item in ds:
        for key in ("question", "answer", "metadata"):
            assert key in item

        metadata = item["metadata"]
        assert isinstance(metadata["polynomial_expr"], str)
        assert isinstance(metadata["single_variable"], bool)

        # The stored expression string must be parseable by sympy (raises otherwise).
        sp.sympify(metadata["polynomial_expr"])
def test_polynomial_equations_dataset_deterministic():
    """Two datasets built from the same seeded config must produce identical items."""
    cfg = PolynomialMultiplicationConfig(seed=999, size=3)
    first, second = PolynomialMultiplicationDataset(cfg), PolynomialMultiplicationDataset(cfg)

    for idx in range(len(first)):
        assert first[idx] == second[idx], "Polynomial datasets with same seed should match exactly."
def test_polynomial_solutions_evaluation():
    """Check that each item's answer equals the expansion of its stored expression."""
    params = {
        "min_terms": 2,
        "max_terms": 4,
        "min_value": 1,
        "max_value": 10,
        "min_degree": 1,
        "max_degree": 3,
        "min_polynomials": 2,
        "max_polynomials": 5,
        "single_variable": False,
        "size": 5,
        "seed": 42,
    }
    ds = create_dataset("polynomial_multiplication", **params)

    for item in ds:
        # Re-expand the unexpanded product recorded in the metadata ...
        recomputed = sp.expand(item["metadata"]["polynomial_expr"])
        # ... and require it to match the answer the dataset produced.
        assert recomputed == item["answer"]
def test_score_function():
    """Check score_answer rewards for the single-variable dataset (seed 42, first item)."""
    params = {
        "min_terms": 2,
        "max_terms": 4,
        "min_value": 1,
        "max_value": 10,
        "min_degree": 1,
        "max_degree": 3,
        "min_polynomials": 2,
        "max_polynomials": 5,
        "single_variable": True,
        "size": 1,
        "seed": 42,
    }
    ds = create_dataset("polynomial_multiplication", **params)
    metadata = ds[0]["metadata"]

    # (candidate answer, expected reward): missing, correct, unparseable, wrong-but-parseable
    expectations = [
        (None, 0.00),
        ("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", 1),
        ("Not a polynomial", 0.01),
        ("x**4", 0.05),
    ]
    for candidate, expected in expectations:
        assert ds.score_answer(candidate, metadata) == expected
def test_multivariate_score_function():
    """Check score_answer rewards for the multi-variable dataset (seed 42, first item)."""
    params = {
        "min_terms": 2,
        "max_terms": 4,
        "min_value": 1,
        "max_value": 10,
        "min_degree": 1,
        "max_degree": 3,
        "min_polynomials": 2,
        "max_polynomials": 5,
        "single_variable": False,
        "size": 1,
        "seed": 42,
    }
    ds = create_dataset("polynomial_multiplication", **params)
    metadata = ds[0]["metadata"]

    # (candidate answer, expected reward): missing, correct, unparseable, wrong-but-parseable
    expectations = [
        (None, 0.00),
        ("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", 1),
        ("Not a polynomial", 0.01),
        ("x**4", 0.05),
    ]
    for candidate, expected in expectations:
        assert ds.score_answer(candidate, metadata) == expected