mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
Refactor PolynomialMultiplicationDataset and fix issues with score_answer
This commit is contained in:
parent
7bad77b426
commit
28fcf4d481
2 changed files with 40 additions and 76 deletions
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import sympy as sp
|
||||
from sympy.polys.monomials import itermonomials
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig:
|
|||
max_terms: int = 4 # Maximum number of polynomial terms
|
||||
min_value: int = 1 # Minimum value for coefficients
|
||||
max_value: int = 100 # Maximum value for coefficients
|
||||
min_degree: int = 1 # Minimum polynomial degree
|
||||
min_degree: int = 0 # Minimum polynomial degree
|
||||
max_degree: int = 3 # Maximum polynomial degree
|
||||
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
|
||||
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
|
||||
|
|
@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig:
|
|||
assert self.min_value > 0, "min_value must be positive."
|
||||
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
||||
|
||||
assert self.min_degree >= 1, "min_degree must be >= 1."
|
||||
assert self.min_degree >= 0, "min_degree must be >= 0."
|
||||
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
||||
|
||||
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
|
||||
|
|
@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
|||
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
|
||||
- metadata: dict with details (polynomial_expr, result, variables)
|
||||
"""
|
||||
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
polynomial_expr = sp.prod(self._generate_polynomial_product(rng))
|
||||
"""
|
||||
Three Monomial States:
|
||||
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
|
||||
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
|
||||
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
|
||||
"""
|
||||
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
|
||||
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
|
||||
|
||||
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
||||
|
||||
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
|
||||
polynomial_expr = sp.prod(polynomial_terms)
|
||||
product = sp.expand(polynomial_expr)
|
||||
|
||||
return {
|
||||
|
|
@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def _get_variable(self, rng: random.Random) -> str:
|
||||
"""Get a random lowercase variable name"""
|
||||
return rng.choice(self.config.variables)
|
||||
|
||||
def _generate_polynomial_product(self, rng):
|
||||
"""Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables."""
|
||||
|
||||
variable = None if self.config.allow_cross_variable_product else self._get_variable(rng)
|
||||
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
||||
|
||||
def _get_monomials(self, rng: random.Random) -> str:
|
||||
"""Get a list of monomials"""
|
||||
if self.config.allow_multivariate_polynomials:
|
||||
generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)]
|
||||
sym = sp.symbols(self.config.variables)
|
||||
else:
|
||||
generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)]
|
||||
sym = [sp.symbols(rng.choice(self.config.variables))]
|
||||
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
|
||||
return monomials
|
||||
|
||||
return generated
|
||||
|
||||
def _generate_multivariate_polynomial(self, rng: random.Random):
|
||||
"""Generates a multivariate polynomial, returns variable set and expression."""
|
||||
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
|
||||
"""Generates a random polynomial, returns expression."""
|
||||
# Choose the number of terms and their respective degrees
|
||||
monomials = monomials if monomials else self._get_monomials(rng)
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
|
||||
polynomial_expr = 0
|
||||
|
|
@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
|||
# Pick a nonzero random coefficient between min_value and max_value.
|
||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
# Build the monomial by choosing each exponent independently.
|
||||
monomial = 1
|
||||
var = self._get_variable(rng)
|
||||
for v in var:
|
||||
v = sp.Symbol(v)
|
||||
exp = random.randint(self.config.min_degree, self.config.max_degree)
|
||||
monomial *= v**exp
|
||||
# Pick a random monomial
|
||||
var = rng.choice(monomials)
|
||||
|
||||
# If '-' in operators, we can randomly flip the sign
|
||||
if "-" in self.config.operators and rng.random() < 0.5:
|
||||
coeff = -coeff
|
||||
|
||||
polynomial_expr += coeff * monomial
|
||||
|
||||
return polynomial_expr
|
||||
|
||||
def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]):
|
||||
"""
|
||||
Randomly generate a polynomial expression of 'degree'.
|
||||
We'll use the config parameters:
|
||||
- min_terms, max_terms: how many total terms to combine
|
||||
- min_value, max_value: range for coefficients
|
||||
- operators: to decide sign flips or direct addition
|
||||
|
||||
Args:
|
||||
rng: Random number generator
|
||||
|
||||
Returns:
|
||||
Polynomial string
|
||||
"""
|
||||
variable = variable if variable else self._get_variable(rng)
|
||||
degree = rng.randint(self.config.min_degree, self.config.max_degree)
|
||||
|
||||
x = sp.Symbol(variable)
|
||||
|
||||
# Choose the number of terms and their respective degrees
|
||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
|
||||
chosen_exponents = [degree]
|
||||
# Fill the rest randomly in [0, degree]
|
||||
for _ in range(num_terms - 1):
|
||||
exp = rng.randint(0, degree)
|
||||
chosen_exponents.append(exp)
|
||||
|
||||
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
|
||||
polynomial_expr = 0
|
||||
for exp in chosen_exponents:
|
||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
||||
# If '-' in operators, we can randomly flip the sign
|
||||
if "-" in self.config.operators and rng.random() < 0.5:
|
||||
coeff = -coeff
|
||||
term_expr = coeff * (x**exp)
|
||||
polynomial_expr += term_expr
|
||||
polynomial_expr += coeff * var
|
||||
|
||||
return polynomial_expr
|
||||
|
||||
|
|
@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
|||
metadata = entry["metadata"]
|
||||
if answer is not None:
|
||||
try:
|
||||
predicted_poly = sp.Poly(answer)
|
||||
target_poly = sp.Poly(metadata["result"])
|
||||
predicted_poly = sp.parse_expr(answer)
|
||||
target_poly = sp.parse_expr(metadata["result"])
|
||||
|
||||
# Check if the difference simplifies to zero (i.e. they are equivalent).
|
||||
if predicted_poly == target_poly:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def test_polynomial_config_validation():
|
|||
PolynomialMultiplicationConfig(min_value=0).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
|
||||
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
||||
|
|
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
|
|||
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(variables=tuple("")).validate()
|
||||
PolynomialMultiplicationConfig(variables="").validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(
|
||||
|
|
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
|
|||
max_degree=2,
|
||||
min_polynomials=2,
|
||||
max_polynomials=5,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=3,
|
||||
|
|
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
|
|||
max_degree=3,
|
||||
min_polynomials=2,
|
||||
max_polynomials=5,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=5,
|
||||
|
|
@ -257,18 +257,20 @@ def test_score_function():
|
|||
max_degree=3,
|
||||
min_polynomials=3,
|
||||
max_polynomials=3,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=1,
|
||||
size=3,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
for item in ds:
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
poly_expr = sp.expand(poly_str)
|
||||
assert ds.score_answer(poly_str, item) == 0.05
|
||||
|
||||
poly_expr = str(sp.expand(poly_str))
|
||||
assert ds.score_answer(poly_expr, item) == 1.0
|
||||
|
||||
assert ds.score_answer(None, item) == 0.00
|
||||
assert ds.score_answer("Not a polynomial", item) == 0.01
|
||||
assert ds.score_answer("x**4", item) == 0.05
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue