Refactor PolynomialMultiplicationDataset and fix issues with score_answer

This commit is contained in:
tohskai 2025-02-17 17:04:48 +01:00
parent 7bad77b426
commit 28fcf4d481
2 changed files with 40 additions and 76 deletions

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import sympy as sp import sympy as sp
from sympy.polys.monomials import itermonomials
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig:
max_terms: int = 4 # Maximum number of polynomial terms max_terms: int = 4 # Maximum number of polynomial terms
min_value: int = 1 # Minimum value for coefficients min_value: int = 1 # Minimum value for coefficients
max_value: int = 100 # Maximum value for coefficients max_value: int = 100 # Maximum value for coefficients
min_degree: int = 1 # Minimum polynomial degree min_degree: int = 0 # Minimum polynomial degree
max_degree: int = 3 # Maximum polynomial degree max_degree: int = 3 # Maximum polynomial degree
min_polynomials: int = 2 # Minimum number of polynomials being multiplied min_polynomials: int = 2 # Minimum number of polynomials being multiplied
max_polynomials: int = 3 # Maximum number of polynomials being multiplied max_polynomials: int = 3 # Maximum number of polynomials being multiplied
@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig:
assert self.min_value > 0, "min_value must be positive." assert self.min_value > 0, "min_value must be positive."
assert self.max_value >= self.min_value, "max_value must be >= min_value." assert self.max_value >= self.min_value, "max_value must be >= min_value."
assert self.min_degree >= 1, "min_degree must be >= 1." assert self.min_degree >= 0, "min_degree must be >= 0."
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree." assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
assert self.min_polynomials >= 2, "min_polynomials must be >= 2." assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6") - answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
- metadata: dict with details (polynomial_expr, result, variables) - metadata: dict with details (polynomial_expr, result, variables)
""" """
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
polynomial_expr = sp.prod(self._generate_polynomial_product(rng)) """
Three Monomial States:
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
"""
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
polynomial_expr = sp.prod(polynomial_terms)
product = sp.expand(polynomial_expr) product = sp.expand(polynomial_expr)
return { return {
@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
}, },
} }
def _get_variable(self, rng: random.Random) -> str: def _get_monomials(self, rng: random.Random) -> str:
"""Get a random lowercase variable name""" """Get a list of monomials"""
return rng.choice(self.config.variables)
def _generate_polynomial_product(self, rng):
"""Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables."""
variable = None if self.config.allow_cross_variable_product else self._get_variable(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
if self.config.allow_multivariate_polynomials: if self.config.allow_multivariate_polynomials:
generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)] sym = sp.symbols(self.config.variables)
else: else:
generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)] sym = [sp.symbols(rng.choice(self.config.variables))]
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
return monomials
return generated def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
"""Generates a random polynomial, returns expression."""
def _generate_multivariate_polynomial(self, rng: random.Random):
"""Generates a multivariate polynomial, returns variable set and expression."""
# Choose the number of terms and their respective degrees # Choose the number of terms and their respective degrees
monomials = monomials if monomials else self._get_monomials(rng)
num_terms = rng.randint(self.config.min_terms, self.config.max_terms) num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
polynomial_expr = 0 polynomial_expr = 0
@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
# Pick a nonzero random coefficient between min_value and max_value. # Pick a nonzero random coefficient between min_value and max_value.
coeff = rng.randint(self.config.min_value, self.config.max_value) coeff = rng.randint(self.config.min_value, self.config.max_value)
# Build the monomial by choosing each exponent independently. # Pick a random monomial
monomial = 1 var = rng.choice(monomials)
var = self._get_variable(rng)
for v in var:
v = sp.Symbol(v)
exp = random.randint(self.config.min_degree, self.config.max_degree)
monomial *= v**exp
# If '-' in operators, we can randomly flip the sign # If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5: if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff coeff = -coeff
polynomial_expr += coeff * monomial polynomial_expr += coeff * var
return polynomial_expr
def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]):
"""
Randomly generate a polynomial expression of 'degree'.
We'll use the config parameters:
- min_terms, max_terms: how many total terms to combine
- min_value, max_value: range for coefficients
- operators: to decide sign flips or direct addition
Args:
rng: Random number generator
Returns:
Polynomial string
"""
variable = variable if variable else self._get_variable(rng)
degree = rng.randint(self.config.min_degree, self.config.max_degree)
x = sp.Symbol(variable)
# Choose the number of terms and their respective degrees
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
chosen_exponents = [degree]
# Fill the rest randomly in [0, degree]
for _ in range(num_terms - 1):
exp = rng.randint(0, degree)
chosen_exponents.append(exp)
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
polynomial_expr = 0
for exp in chosen_exponents:
coeff = rng.randint(self.config.min_value, self.config.max_value)
# If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff
term_expr = coeff * (x**exp)
polynomial_expr += term_expr
return polynomial_expr return polynomial_expr
@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
metadata = entry["metadata"] metadata = entry["metadata"]
if answer is not None: if answer is not None:
try: try:
predicted_poly = sp.Poly(answer) predicted_poly = sp.parse_expr(answer)
target_poly = sp.Poly(metadata["result"]) target_poly = sp.parse_expr(metadata["result"])
# Check if the difference simplifies to zero (i.e. they are equivalent). # Check if the difference simplifies to zero (i.e. they are equivalent).
if predicted_poly == target_poly: if predicted_poly == target_poly:

View file

@ -19,7 +19,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_value=0).validate() PolynomialMultiplicationConfig(min_value=0).validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate() PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate() PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate() PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(variables=tuple("")).validate() PolynomialMultiplicationConfig(variables="").validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig( PolynomialMultiplicationConfig(
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
max_degree=2, max_degree=2,
min_polynomials=2, min_polynomials=2,
max_polynomials=5, max_polynomials=5,
variables=tuple(["x", "y", "xy"]), variables=tuple(["x", "y"]),
allow_cross_variable_product=True, allow_cross_variable_product=True,
allow_multivariate_polynomials=True, allow_multivariate_polynomials=True,
size=3, size=3,
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
max_degree=3, max_degree=3,
min_polynomials=2, min_polynomials=2,
max_polynomials=5, max_polynomials=5,
variables=tuple(["x", "y", "xy"]), variables=tuple(["x", "y"]),
allow_cross_variable_product=True, allow_cross_variable_product=True,
allow_multivariate_polynomials=True, allow_multivariate_polynomials=True,
size=5, size=5,
@ -257,18 +257,20 @@ def test_score_function():
max_degree=3, max_degree=3,
min_polynomials=3, min_polynomials=3,
max_polynomials=3, max_polynomials=3,
variables=tuple(["x", "y", "xy"]), variables=tuple(["x", "y"]),
allow_cross_variable_product=True, allow_cross_variable_product=True,
allow_multivariate_polynomials=True, allow_multivariate_polynomials=True,
size=1, size=3,
seed=42, seed=42,
) )
for item in ds: for item in ds:
poly_str = item["metadata"]["polynomial_expr"] poly_str = item["metadata"]["polynomial_expr"]
poly_expr = sp.expand(poly_str) assert ds.score_answer(poly_str, item) == 0.05
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0 assert ds.score_answer(poly_expr, item) == 1.0
assert ds.score_answer(None, item) == 0.00 assert ds.score_answer(None, item) == 0.00
assert ds.score_answer("Not a polynomial", item) == 0.01 assert ds.score_answer("Not a polynomial", item) == 0.01
assert ds.score_answer("x**4", item) == 0.05 assert ds.score_answer("x**4", item) == 0.05