Refactor PolynomialMultiplicationDataset and fix issues with score_answer

This commit is contained in:
tohskai 2025-02-17 17:04:48 +01:00
parent 7bad77b426
commit 28fcf4d481
2 changed files with 40 additions and 76 deletions

View file

@ -19,7 +19,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_value=0).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(variables=tuple("")).validate()
PolynomialMultiplicationConfig(variables="").validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
max_degree=3,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=5,
@ -257,18 +257,20 @@ def test_score_function():
max_degree=3,
min_polynomials=3,
max_polynomials=3,
variables=tuple(["x", "y", "xy"]),
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=1,
size=3,
seed=42,
)
for item in ds:
poly_str = item["metadata"]["polynomial_expr"]
poly_expr = sp.expand(poly_str)
assert ds.score_answer(poly_str, item) == 0.05
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0
assert ds.score_answer(None, item) == 0.00
assert ds.score_answer("Not a polynomial", item) == 0.01
assert ds.score_answer("x**4", item) == 0.05