From 28fcf4d481657d415c3b00a868063a38463e8fb5 Mon Sep 17 00:00:00 2001 From: tohskai Date: Mon, 17 Feb 2025 17:04:48 +0100 Subject: [PATCH] Refactor PolynomialMultiplicationDataset and fix issues with score_answer --- .../algebra/polynomial_multiplication.py | 100 ++++++------------ tests/test_polynomial_multiplication.py | 16 +-- 2 files changed, 40 insertions(+), 76 deletions(-) diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index 36f3ea1a..4e3712de 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple import sympy as sp +from sympy.polys.monomials import itermonomials from ..factory import ProceduralDataset, register_dataset @@ -18,7 +19,7 @@ class PolynomialMultiplicationConfig: max_terms: int = 4 # Maximum number of polynomial terms min_value: int = 1 # Minimum value for coefficients max_value: int = 100 # Maximum value for coefficients - min_degree: int = 1 # Minimum polynomial degree + min_degree: int = 0 # Minimum polynomial degree max_degree: int = 3 # Maximum polynomial degree min_polynomials: int = 2 # Minimum number of polynomials being multiplied max_polynomials: int = 3 # Maximum number of polynomials being multiplied @@ -40,7 +41,7 @@ class PolynomialMultiplicationConfig: assert self.min_value > 0, "min_value must be positive." assert self.max_value >= self.min_value, "max_value must be >= min_value." - assert self.min_degree >= 1, "min_degree must be >= 1." + assert self.min_degree >= 0, "min_degree must be >= 0." assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree." assert self.min_polynomials >= 2, "min_polynomials must be >= 2." @@ -80,9 +81,22 @@ class PolynomialMultiplicationDataset(ProceduralDataset): - answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6") - metadata: dict with details (polynomial_expr, result, variables) """ + rng = random.Random(self.seed + idx) - polynomial_expr = sp.prod(self._generate_polynomial_product(rng)) + """ + Three Monomial States: + - allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2]) + - allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term + - allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1]) + """ + monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None + monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng) + + number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials) + + polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)] + polynomial_expr = sp.prod(polynomial_terms) product = sp.expand(polynomial_expr) return { @@ -97,26 +111,19 @@ class PolynomialMultiplicationDataset(ProceduralDataset): }, } - def _get_variable(self, rng: random.Random) -> str: - """Get a random lowercase variable name""" - return rng.choice(self.config.variables) - - def _generate_polynomial_product(self, rng): - """Helper for selecting regular or multivariate polynomial. Returns expressions and unique variables.""" - - variable = None if self.config.allow_cross_variable_product else self._get_variable(rng) - number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials) - + def _get_monomials(self, rng: random.Random) -> str: + """Get a list of monomials""" if self.config.allow_multivariate_polynomials: - generated = [self._generate_multivariate_polynomial(rng) for _ in range(number_polynomials)] + sym = sp.symbols(self.config.variables) else: - generated = [self._generate_regular_polynomial(rng, variable) for _ in range(number_polynomials)] + sym = [sp.symbols(rng.choice(self.config.variables))] + monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree)) + return monomials - return generated - - def _generate_multivariate_polynomial(self, rng: random.Random): - """Generates a multivariate polynomial, returns variable set and expression.""" + def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]): + """Generates a random polynomial, returns expression.""" # Choose the number of terms and their respective degrees + monomials = monomials if monomials else self._get_monomials(rng) num_terms = rng.randint(self.config.min_terms, self.config.max_terms) polynomial_expr = 0 @@ -124,59 +131,14 @@ class PolynomialMultiplicationDataset(ProceduralDataset): # Pick a nonzero random coefficient between min_value and max_value. coeff = rng.randint(self.config.min_value, self.config.max_value) - # Build the monomial by choosing each exponent independently. - monomial = 1 - var = self._get_variable(rng) - for v in var: - v = sp.Symbol(v) - exp = random.randint(self.config.min_degree, self.config.max_degree) - monomial *= v**exp + # Pick a random monomial + var = rng.choice(monomials) # If '-' in operators, we can randomly flip the sign if "-" in self.config.operators and rng.random() < 0.5: coeff = -coeff - polynomial_expr += coeff * monomial - - return polynomial_expr - - def _generate_regular_polynomial(self, rng: random.Random, variable: Optional[str]): - """ - Randomly generate a polynomial expression of 'degree'. - We'll use the config parameters: - - min_terms, max_terms: how many total terms to combine - - min_value, max_value: range for coefficients - - operators: to decide sign flips or direct addition - - Args: - rng: Random number generator - - Returns: - Polynomial string - """ - variable = variable if variable else self._get_variable(rng) - degree = rng.randint(self.config.min_degree, self.config.max_degree) - - x = sp.Symbol(variable) - - # Choose the number of terms and their respective degrees - num_terms = rng.randint(self.config.min_terms, self.config.max_terms) - # Keep track of exponents, exponents can repeat or skip but we force the highest exponent - chosen_exponents = [degree] - # Fill the rest randomly in [0, degree] - for _ in range(num_terms - 1): - exp = rng.randint(0, degree) - chosen_exponents.append(exp) - - # Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign - polynomial_expr = 0 - for exp in chosen_exponents: - coeff = rng.randint(self.config.min_value, self.config.max_value) - # If '-' in operators, we can randomly flip the sign - if "-" in self.config.operators and rng.random() < 0.5: - coeff = -coeff - term_expr = coeff * (x**exp) - polynomial_expr += term_expr + polynomial_expr += coeff * var return polynomial_expr @@ -185,8 +147,8 @@ class PolynomialMultiplicationDataset(ProceduralDataset): metadata = entry["metadata"] if answer is not None: try: - predicted_poly = sp.Poly(answer) - target_poly = sp.Poly(metadata["result"]) + predicted_poly = sp.parse_expr(answer) + target_poly = sp.parse_expr(metadata["result"]) # Check if the difference simplifies to zero (i.e. they are equivalent). if predicted_poly == target_poly: diff --git a/tests/test_polynomial_multiplication.py b/tests/test_polynomial_multiplication.py index 35d1fdd3..10b404d9 100644 --- a/tests/test_polynomial_multiplication.py +++ b/tests/test_polynomial_multiplication.py @@ -19,7 +19,7 @@ def test_polynomial_config_validation(): PolynomialMultiplicationConfig(min_value=0).validate() with pytest.raises(AssertionError): - PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate() + PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate() with pytest.raises(AssertionError): PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate() @@ -31,7 +31,7 @@ def test_polynomial_config_validation(): PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate() with pytest.raises(AssertionError): - PolynomialMultiplicationConfig(variables=tuple("")).validate() + PolynomialMultiplicationConfig(variables="").validate() with pytest.raises(AssertionError): PolynomialMultiplicationConfig( @@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items(): max_degree=2, min_polynomials=2, max_polynomials=5, - variables=tuple(["x", "y", "xy"]), + variables=tuple(["x", "y"]), allow_cross_variable_product=True, allow_multivariate_polynomials=True, size=3, @@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation(): max_degree=3, min_polynomials=2, max_polynomials=5, - variables=tuple(["x", "y", "xy"]), + variables=tuple(["x", "y"]), allow_cross_variable_product=True, allow_multivariate_polynomials=True, size=5, @@ -257,18 +257,20 @@ def test_score_function(): max_degree=3, min_polynomials=3, max_polynomials=3, - variables=tuple(["x", "y", "xy"]), + variables=tuple(["x", "y"]), allow_cross_variable_product=True, allow_multivariate_polynomials=True, - size=1, + size=3, seed=42, ) for item in ds: poly_str = item["metadata"]["polynomial_expr"] - poly_expr = sp.expand(poly_str) + assert ds.score_answer(poly_str, item) == 0.05 + poly_expr = str(sp.expand(poly_str)) assert ds.score_answer(poly_expr, item) == 1.0 + assert ds.score_answer(None, item) == 0.00 assert ds.score_answer("Not a polynomial", item) == 0.01 assert ds.score_answer("x**4", item) == 0.05