Refactor PolynomialMultiplicationDataset and fix issues with score_answer

This commit is contained in:
tohskai 2025-02-17 17:04:48 +01:00
parent 7bad77b426
commit 28fcf4d481
2 changed files with 40 additions and 76 deletions

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import sympy as sp
from sympy.polys.monomials import itermonomials
from ..factory import ProceduralDataset, register_dataset
@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig:
max_terms: int = 4 # Maximum number of polynomial terms
min_value: int = 1 # Minimum value for coefficients
max_value: int = 100 # Maximum value for coefficients
min_degree: int = 1 # Minimum polynomial degree
min_degree: int = 0 # Minimum polynomial degree
max_degree: int = 3 # Maximum polynomial degree
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig:
assert self.min_value > 0, "min_value must be positive."
assert self.max_value >= self.min_value, "max_value must be >= min_value."
assert self.min_degree >= 1, "min_degree must be >= 1."
assert self.min_degree >= 0, "min_degree must be >= 0."
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
- metadata: dict with details (polynomial_expr, result, variables)
"""
rng = random.Random(self.seed + idx)
polynomial_expr = sp.prod(self._generate_polynomial_product(rng))
"""
Three Monomial States:
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
"""
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
polynomial_expr = sp.prod(polynomial_terms)
product = sp.expand(polynomial_expr)
return {
@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
},
}
def _get_variable(self, rng: random.Random) -> str:
"""Get a random lowercase variable name"""
return rng.choice(self.config.variables)
def _generate_polynomial_product(self, rng):
"""Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables."""
variable = None if self.config.allow_cross_variable_product else self._get_variable(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
def _get_monomials(self, rng: random.Random) -> str:
"""Get a list of monomials"""
if self.config.allow_multivariate_polynomials:
generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)]
sym = sp.symbols(self.config.variables)
else:
generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)]
sym = [sp.symbols(rng.choice(self.config.variables))]
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
return monomials
return generated
def _generate_multivariate_polynomial(self, rng: random.Random):
"""Generates a multivariate polynomial, returns variable set and expression."""
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
"""Generates a random polynomial, returns expression."""
# Choose the number of terms and their respective degrees
monomials = monomials if monomials else self._get_monomials(rng)
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
polynomial_expr = 0
@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
# Pick a nonzero random coefficient between min_value and max_value.
coeff = rng.randint(self.config.min_value, self.config.max_value)
# Build the monomial by choosing each exponent independently.
monomial = 1
var = self._get_variable(rng)
for v in var:
v = sp.Symbol(v)
exp = random.randint(self.config.min_degree, self.config.max_degree)
monomial *= v**exp
# Pick a random monomial
var = rng.choice(monomials)
# If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff
polynomial_expr += coeff * monomial
return polynomial_expr
def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]):
"""
Randomly generate a polynomial expression of 'degree'.
We'll use the config parameters:
- min_terms, max_terms: how many total terms to combine
- min_value, max_value: range for coefficients
- operators: to decide sign flips or direct addition
Args:
rng: Random number generator
Returns:
Polynomial string
"""
variable = variable if variable else self._get_variable(rng)
degree = rng.randint(self.config.min_degree, self.config.max_degree)
x = sp.Symbol(variable)
# Choose the number of terms and their respective degrees
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
chosen_exponents = [degree]
# Fill the rest randomly in [0, degree]
for _ in range(num_terms - 1):
exp = rng.randint(0, degree)
chosen_exponents.append(exp)
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
polynomial_expr = 0
for exp in chosen_exponents:
coeff = rng.randint(self.config.min_value, self.config.max_value)
# If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff
term_expr = coeff * (x**exp)
polynomial_expr += term_expr
polynomial_expr += coeff * var
return polynomial_expr
@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
metadata = entry["metadata"]
if answer is not None:
try:
predicted_poly = sp.Poly(answer)
target_poly = sp.Poly(metadata["result"])
predicted_poly = sp.parse_expr(answer)
target_poly = sp.parse_expr(metadata["result"])
# Check if the difference simplifies to zero (i.e. they are equivalent).
if predicted_poly == target_poly: