Refactor PolynomialMultiplicationDataset and fix issues with score_answer

This commit is contained in:
tohskai 2025-02-17 17:04:48 +01:00
parent 7bad77b426
commit 28fcf4d481
2 changed files with 40 additions and 76 deletions

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import sympy as sp
from sympy.polys.monomials import itermonomials
from ..factory import ProceduralDataset, register_dataset
@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig:
max_terms: int = 4 # Maximum number of polynomial terms
min_value: int = 1 # Minimum value for coefficients
max_value: int = 100 # Maximum value for coefficients
min_degree: int = 1 # Minimum polynomial degree
min_degree: int = 0 # Minimum polynomial degree
max_degree: int = 3 # Maximum polynomial degree
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig:
assert self.min_value > 0, "min_value must be positive."
assert self.max_value >= self.min_value, "max_value must be >= min_value."
assert self.min_degree >= 1, "min_degree must be >= 1."
assert self.min_degree >= 0, "min_degree must be >= 0."
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
- metadata: dict with details (polynomial_expr, result, variables)
"""
rng = random.Random(self.seed + idx)
polynomial_expr = sp.prod(self._generate_polynomial_product(rng))
"""
Three Monomial States:
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
"""
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
polynomial_expr = sp.prod(polynomial_terms)
product = sp.expand(polynomial_expr)
return {
@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
},
}
def _get_variable(self, rng: random.Random) -> str:
"""Get a random lowercase variable name"""
return rng.choice(self.config.variables)
def _generate_polynomial_product(self, rng):
"""Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables."""
variable = None if self.config.allow_cross_variable_product else self._get_variable(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
def _get_monomials(self, rng: random.Random) -> str:
"""Get a list of monomials"""
if self.config.allow_multivariate_polynomials:
generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)]
sym = sp.symbols(self.config.variables)
else:
generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)]
sym = [sp.symbols(rng.choice(self.config.variables))]
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
return monomials
return generated
def _generate_multivariate_polynomial(self, rng: random.Random):
"""Generates a multivariate polynomial, returns variable set and expression."""
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
"""Generates a random polynomial, returns expression."""
# Choose the number of terms and their respective degrees
monomials = monomials if monomials else self._get_monomials(rng)
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
polynomial_expr = 0
@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
# Pick a nonzero random coefficient between min_value and max_value.
coeff = rng.randint(self.config.min_value, self.config.max_value)
# Build the monomial by choosing each exponent independently.
monomial = 1
var = self._get_variable(rng)
for v in var:
v = sp.Symbol(v)
exp = random.randint(self.config.min_degree, self.config.max_degree)
monomial *= v**exp
# Pick a random monomial
var = rng.choice(monomials)
# If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff
polynomial_expr += coeff * monomial
return polynomial_expr
def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]):
"""
Randomly generate a polynomial expression of 'degree'.
We'll use the config parameters:
- min_terms, max_terms: how many total terms to combine
- min_value, max_value: range for coefficients
- operators: to decide sign flips or direct addition
Args:
rng: Random number generator
Returns:
Polynomial string
"""
variable = variable if variable else self._get_variable(rng)
degree = rng.randint(self.config.min_degree, self.config.max_degree)
x = sp.Symbol(variable)
# Choose the number of terms and their respective degrees
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
chosen_exponents = [degree]
# Fill the rest randomly in [0, degree]
for _ in range(num_terms - 1):
exp = rng.randint(0, degree)
chosen_exponents.append(exp)
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
polynomial_expr = 0
for exp in chosen_exponents:
coeff = rng.randint(self.config.min_value, self.config.max_value)
# If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff
term_expr = coeff * (x**exp)
polynomial_expr += term_expr
polynomial_expr += coeff * var
return polynomial_expr
@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
metadata = entry["metadata"]
if answer is not None:
try:
predicted_poly = sp.Poly(answer)
target_poly = sp.Poly(metadata["result"])
predicted_poly = sp.parse_expr(answer)
target_poly = sp.parse_expr(metadata["result"])
# Check if the difference simplifies to zero (i.e. they are equivalent).
if predicted_poly == target_poly:

View file

@ -19,7 +19,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_value=0).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(variables=tuple("")).validate()
PolynomialMultiplicationConfig(variables="").validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
max_degree=3,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=5,
@ -257,18 +257,20 @@ def test_score_function():
max_degree=3,
min_polynomials=3,
max_polynomials=3,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=1,
size=3,
seed=42,
)
for item in ds:
poly_str = item["metadata"]["polynomial_expr"]
poly_expr = sp.expand(poly_str)
assert ds.score_answer(poly_str, item) == 0.05
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0
assert ds.score_answer(None, item) == 0.00
assert ds.score_answer("Not a polynomial", item) == 0.01
assert ds.score_answer("x**4", item) == 0.05