reasoning-gym/reasoning_gym/algebra/polynomial_multiplication.py
tohskai 847442ef0a
Add PolynomialMultiplicationDataset (#64)
* Add PolynomialMultiplicationDataset
2025-02-07 14:06:41 +01:00

161 lines
6.3 KiB
Python

import random
import string
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import sympy as sp
from sympy import Eq, Symbol, expand, solve
from ..factory import ProceduralDataset, register_dataset
@dataclass
class PolynomialMultiplicationConfig:
    """
    Configuration for polynomial multiplication task generation.

    Call `validate()` to check that the ranges are consistent; it raises
    `AssertionError` on the first violated constraint.
    """

    min_terms: int = 2  # Minimum number of polynomial terms
    max_terms: int = 4  # Maximum number of polynomial terms
    min_value: int = 1  # Minimum value for coefficients
    max_value: int = 100  # Maximum value for coefficients
    min_degree: int = 1  # Minimum polynomial degree
    max_degree: int = 3  # Maximum polynomial degree
    min_polynomials: int = 2  # Minimum number of polynomials being multiplied
    max_polynomials: int = 3  # Maximum number of polynomials being multiplied
    # If True every polynomial uses the variable "x"; otherwise each polynomial
    # draws a random lowercase letter. The default is a real bool, not the
    # one-element tuple `(True,)` that a stray trailing comma would produce.
    single_variable: bool = True
    operators: Tuple[str, ...] = (
        "+",
        "-",
    )  # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
    seed: Optional[int] = None  # Seed handed to the dataset (None = unspecified)
    size: int = 500  # Dataset size handed to the dataset

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if any range is empty/inverted, a bound is out of
                range, or `operators` is empty or contains anything other
                than "+" and "-".
        """
        assert self.min_terms > 0, "min_terms must be positive."
        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."

        assert self.min_value > 0, "min_value must be positive."
        assert self.max_value >= self.min_value, "max_value must be >= min_value."

        assert self.min_degree >= 1, "min_degree must be >= 1."
        assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."

        assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
        assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."

        allowed_ops = {"+", "-"}
        assert len(self.operators) > 0, "operators tuple cannot be empty."
        assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
class PolynomialMultiplicationDataset(ProceduralDataset):
    """
    Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree].
    - Each polynomial is formed by summing random terms of the form: coeff * x^exponent.
    - Then we find "F = P_0 * ... * P_n" using Sympy.
    """

    def __init__(self, config: PolynomialMultiplicationConfig):
        # Templates must be set before the base-class constructor runs, in case
        # it touches items during initialisation.
        self._prompt_templates = [
            "Simplify this expression: {polynomial_expr}",
            "Calculate the following: {polynomial_expr}",
        ]
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single polynomial multiplication item.

        The item is a deterministic function of `self.seed + idx`.
        (`self.seed` and `self.config` are presumably set by the
        ProceduralDataset constructor -- confirm against the base class.)

        Args:
            idx: Index of the item to generate.

        Returns:
            A dict with:
                - question: str (e.g. "Simplify this expression: (x - 3)*(8*x**3 + x + 2)")
                - answer: str (expanded product, e.g. "8*x**4 - 24*x**3 + x**2 - x - 6")
                - metadata: dict with details (polynomial_expr, single_variable, result)
        """
        rng = random.Random(self.seed + idx)

        number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
        polynomials = [self._generate_polynomial_expr(rng) for _ in range(number_polynomials)]

        polynomial_expr = sp.prod(polynomials)
        # Answers are exchanged as strings (score_answer parses strings), so
        # stringify once here instead of leaking a SymPy object to callers.
        expr_text = str(polynomial_expr)
        product_text = str(sp.expand(polynomial_expr))

        return {
            "question": rng.choice(self._prompt_templates).format(
                polynomial_expr=expr_text,
            ),
            "answer": product_text,
            "metadata": {
                "polynomial_expr": expr_text,
                "single_variable": self.config.single_variable,
                "result": product_text,
            },
        }

    def _get_variable(self, rng: random.Random) -> str:
        """Return "x" in single-variable mode, else a random lowercase letter."""
        if self.config.single_variable:
            return "x"
        return rng.choice(string.ascii_lowercase)

    def _generate_polynomial_expr(self, rng: random.Random):
        """
        Randomly generate one polynomial expression.

        We'll use the config parameters:
            - min_degree, max_degree: range for the leading exponent
            - min_terms, max_terms: how many total terms to combine
            - min_value, max_value: range for coefficient magnitudes
            - operators: if '-' is allowed, each coefficient is negated with probability 0.5

        Args:
            rng: Random number generator

        Returns:
            A SymPy expression (not a string). Terms sharing an exponent are
            summed, so the result may have fewer terms than were drawn.
        """
        variable = self._get_variable(rng)
        degree = rng.randint(self.config.min_degree, self.config.max_degree)

        x = Symbol(variable)

        # Choose the number of terms and their respective degrees
        num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
        # Exponents can repeat or skip, but the highest exponent is always drawn
        chosen_exponents = [degree]
        # Fill the rest randomly in [0, degree]
        chosen_exponents.extend(rng.randint(0, degree) for _ in range(num_terms - 1))

        allow_negative = "-" in self.config.operators

        # Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
        polynomial_expr = 0
        for exp in chosen_exponents:
            coeff = rng.randint(self.config.min_value, self.config.max_value)
            # The sign draw only happens when '-' is allowed (short-circuit),
            # which keeps the random stream identical for '+'-only configs.
            if allow_negative and rng.random() < 0.5:
                coeff = -coeff
            polynomial_expr += coeff * (x**exp)

        return polynomial_expr

    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
        """
        Score a predicted answer against the stored expected product.

        Args:
            answer: Model output to score, or None if there is none.
            metadata: Item metadata; its "result" entry holds the expected
                expanded product as a string.

        Returns:
            1.0  if the answer is symbolically equivalent to the target,
            0.05 if it parses but is not equivalent,
            0.01 if it cannot be parsed or compared,
            0.0  if `answer` is None.
        """
        reward = 0.0
        if answer is not None:
            try:
                # NOTE(security): sympy's parse_expr evaluates the string as
                # Python code; only score answers from trusted or sandboxed sources.
                predicted_poly = sp.parse_expr(answer)
                target_poly = sp.parse_expr(metadata["result"])

                # Check if the difference simplifies to zero (i.e. they are equivalent).
                if sp.simplify(predicted_poly - target_poly) == 0:
                    reward = 1.0
                elif answer.strip():
                    reward = 0.05
                else:
                    reward = 0.01
            except Exception:
                # Deliberate best-effort: unparsable or otherwise failing
                # answers get the minimal non-zero reward.
                reward = 0.01
        return reward
# Register the dataset class together with its config class under the name
# "polynomial_multiplication" (presumably so the factory can build it by name
# -- confirm against ..factory).
register_dataset("polynomial_multiplication", PolynomialMultiplicationDataset, PolynomialMultiplicationConfig)