mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
Refactor PolynomialMultiplicationDataset and fix issues with score_answer
This commit is contained in:
parent
7bad77b426
commit
28fcf4d481
2 changed files with 40 additions and 76 deletions
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
from sympy.polys.monomials import itermonomials
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig:
|
||||||
max_terms: int = 4 # Maximum number of polynomial terms
|
max_terms: int = 4 # Maximum number of polynomial terms
|
||||||
min_value: int = 1 # Minimum value for coefficients
|
min_value: int = 1 # Minimum value for coefficients
|
||||||
max_value: int = 100 # Maximum value for coefficients
|
max_value: int = 100 # Maximum value for coefficients
|
||||||
min_degree: int = 1 # Minimum polynomial degree
|
min_degree: int = 0 # Minimum polynomial degree
|
||||||
max_degree: int = 3 # Maximum polynomial degree
|
max_degree: int = 3 # Maximum polynomial degree
|
||||||
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
|
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
|
||||||
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
|
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
|
||||||
|
|
@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig:
|
||||||
assert self.min_value > 0, "min_value must be positive."
|
assert self.min_value > 0, "min_value must be positive."
|
||||||
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
||||||
|
|
||||||
assert self.min_degree >= 1, "min_degree must be >= 1."
|
assert self.min_degree >= 0, "min_degree must be >= 0."
|
||||||
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
||||||
|
|
||||||
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
|
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
|
||||||
|
|
@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
|
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
|
||||||
- metadata: dict with details (polynomial_expr, result, variables)
|
- metadata: dict with details (polynomial_expr, result, variables)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
polynomial_expr = sp.prod(self._generate_polynomial_product(rng))
|
"""
|
||||||
|
Three Monomial States:
|
||||||
|
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
|
||||||
|
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
|
||||||
|
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
|
||||||
|
"""
|
||||||
|
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
|
||||||
|
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
|
||||||
|
|
||||||
|
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
||||||
|
|
||||||
|
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
|
||||||
|
polynomial_expr = sp.prod(polynomial_terms)
|
||||||
product = sp.expand(polynomial_expr)
|
product = sp.expand(polynomial_expr)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_variable(self, rng: random.Random) -> str:
|
def _get_monomials(self, rng: random.Random) -> str:
|
||||||
"""Get a random lowercase variable name"""
|
"""Get a list of monomials"""
|
||||||
return rng.choice(self.config.variables)
|
|
||||||
|
|
||||||
def _generate_polynomial_product(self, rng):
|
|
||||||
"""Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables."""
|
|
||||||
|
|
||||||
variable = None if self.config.allow_cross_variable_product else self._get_variable(rng)
|
|
||||||
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
|
||||||
|
|
||||||
if self.config.allow_multivariate_polynomials:
|
if self.config.allow_multivariate_polynomials:
|
||||||
generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)]
|
sym = sp.symbols(self.config.variables)
|
||||||
else:
|
else:
|
||||||
generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)]
|
sym = [sp.symbols(rng.choice(self.config.variables))]
|
||||||
|
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
|
||||||
|
return monomials
|
||||||
|
|
||||||
return generated
|
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
|
||||||
|
"""Generates a random polynomial, returns expression."""
|
||||||
def _generate_multivariate_polynomial(self, rng: random.Random):
|
|
||||||
"""Generates a multivariate polynomial, returns variable set and expression."""
|
|
||||||
# Choose the number of terms and their respective degrees
|
# Choose the number of terms and their respective degrees
|
||||||
|
monomials = monomials if monomials else self._get_monomials(rng)
|
||||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
|
|
||||||
polynomial_expr = 0
|
polynomial_expr = 0
|
||||||
|
|
@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
# Pick a nonzero random coefficient between min_value and max_value.
|
# Pick a nonzero random coefficient between min_value and max_value.
|
||||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
# Build the monomial by choosing each exponent independently.
|
# Pick a random monomial
|
||||||
monomial = 1
|
var = rng.choice(monomials)
|
||||||
var = self._get_variable(rng)
|
|
||||||
for v in var:
|
|
||||||
v = sp.Symbol(v)
|
|
||||||
exp = random.randint(self.config.min_degree, self.config.max_degree)
|
|
||||||
monomial *= v**exp
|
|
||||||
|
|
||||||
# If '-' in operators, we can randomly flip the sign
|
# If '-' in operators, we can randomly flip the sign
|
||||||
if "-" in self.config.operators and rng.random() < 0.5:
|
if "-" in self.config.operators and rng.random() < 0.5:
|
||||||
coeff = -coeff
|
coeff = -coeff
|
||||||
|
|
||||||
polynomial_expr += coeff * monomial
|
polynomial_expr += coeff * var
|
||||||
|
|
||||||
return polynomial_expr
|
|
||||||
|
|
||||||
def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]):
|
|
||||||
"""
|
|
||||||
Randomly generate a polynomial expression of 'degree'.
|
|
||||||
We'll use the config parameters:
|
|
||||||
- min_terms, max_terms: how many total terms to combine
|
|
||||||
- min_value, max_value: range for coefficients
|
|
||||||
- operators: to decide sign flips or direct addition
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rng: Random number generator
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Polynomial string
|
|
||||||
"""
|
|
||||||
variable = variable if variable else self._get_variable(rng)
|
|
||||||
degree = rng.randint(self.config.min_degree, self.config.max_degree)
|
|
||||||
|
|
||||||
x = sp.Symbol(variable)
|
|
||||||
|
|
||||||
# Choose the number of terms and their respective degrees
|
|
||||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
|
||||||
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
|
|
||||||
chosen_exponents = [degree]
|
|
||||||
# Fill the rest randomly in [0, degree]
|
|
||||||
for _ in range(num_terms - 1):
|
|
||||||
exp = rng.randint(0, degree)
|
|
||||||
chosen_exponents.append(exp)
|
|
||||||
|
|
||||||
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
|
|
||||||
polynomial_expr = 0
|
|
||||||
for exp in chosen_exponents:
|
|
||||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
|
||||||
# If '-' in operators, we can randomly flip the sign
|
|
||||||
if "-" in self.config.operators and rng.random() < 0.5:
|
|
||||||
coeff = -coeff
|
|
||||||
term_expr = coeff * (x**exp)
|
|
||||||
polynomial_expr += term_expr
|
|
||||||
|
|
||||||
return polynomial_expr
|
return polynomial_expr
|
||||||
|
|
||||||
|
|
@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
metadata = entry["metadata"]
|
metadata = entry["metadata"]
|
||||||
if answer is not None:
|
if answer is not None:
|
||||||
try:
|
try:
|
||||||
predicted_poly = sp.Poly(answer)
|
predicted_poly = sp.parse_expr(answer)
|
||||||
target_poly = sp.Poly(metadata["result"])
|
target_poly = sp.parse_expr(metadata["result"])
|
||||||
|
|
||||||
# Check if the difference simplifies to zero (i.e. they are equivalent).
|
# Check if the difference simplifies to zero (i.e. they are equivalent).
|
||||||
if predicted_poly == target_poly:
|
if predicted_poly == target_poly:
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ def test_polynomial_config_validation():
|
||||||
PolynomialMultiplicationConfig(min_value=0).validate()
|
PolynomialMultiplicationConfig(min_value=0).validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
|
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
||||||
|
|
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
|
||||||
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(variables=tuple("")).validate()
|
PolynomialMultiplicationConfig(variables="").validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(
|
PolynomialMultiplicationConfig(
|
||||||
|
|
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
|
||||||
max_degree=2,
|
max_degree=2,
|
||||||
min_polynomials=2,
|
min_polynomials=2,
|
||||||
max_polynomials=5,
|
max_polynomials=5,
|
||||||
variables=tuple(["x", "y", "xy"]),
|
variables=tuple(["x", "y"]),
|
||||||
allow_cross_variable_product=True,
|
allow_cross_variable_product=True,
|
||||||
allow_multivariate_polynomials=True,
|
allow_multivariate_polynomials=True,
|
||||||
size=3,
|
size=3,
|
||||||
|
|
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
|
||||||
max_degree=3,
|
max_degree=3,
|
||||||
min_polynomials=2,
|
min_polynomials=2,
|
||||||
max_polynomials=5,
|
max_polynomials=5,
|
||||||
variables=tuple(["x", "y", "xy"]),
|
variables=tuple(["x", "y"]),
|
||||||
allow_cross_variable_product=True,
|
allow_cross_variable_product=True,
|
||||||
allow_multivariate_polynomials=True,
|
allow_multivariate_polynomials=True,
|
||||||
size=5,
|
size=5,
|
||||||
|
|
@ -257,18 +257,20 @@ def test_score_function():
|
||||||
max_degree=3,
|
max_degree=3,
|
||||||
min_polynomials=3,
|
min_polynomials=3,
|
||||||
max_polynomials=3,
|
max_polynomials=3,
|
||||||
variables=tuple(["x", "y", "xy"]),
|
variables=tuple(["x", "y"]),
|
||||||
allow_cross_variable_product=True,
|
allow_cross_variable_product=True,
|
||||||
allow_multivariate_polynomials=True,
|
allow_multivariate_polynomials=True,
|
||||||
size=1,
|
size=3,
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
for item in ds:
|
for item in ds:
|
||||||
poly_str = item["metadata"]["polynomial_expr"]
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
poly_expr = sp.expand(poly_str)
|
assert ds.score_answer(poly_str, item) == 0.05
|
||||||
|
|
||||||
|
poly_expr = str(sp.expand(poly_str))
|
||||||
assert ds.score_answer(poly_expr, item) == 1.0
|
assert ds.score_answer(poly_expr, item) == 1.0
|
||||||
|
|
||||||
assert ds.score_answer(None, item) == 0.00
|
assert ds.score_answer(None, item) == 0.00
|
||||||
assert ds.score_answer("Not a polynomial", item) == 0.01
|
assert ds.score_answer("Not a polynomial", item) == 0.01
|
||||||
assert ds.score_answer("x**4", item) == 0.05
|
assert ds.score_answer("x**4", item) == 0.05
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue