mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
Refactor PolynomialMultiplicationDataset and fix issues with score_answer
This commit is contained in:
parent
7bad77b426
commit
28fcf4d481
2 changed files with 40 additions and 76 deletions
|
|
@ -19,7 +19,7 @@ def test_polynomial_config_validation():
|
|||
PolynomialMultiplicationConfig(min_value=0).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
|
||||
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
||||
|
|
@ -31,7 +31,7 @@ def test_polynomial_config_validation():
|
|||
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(variables=tuple("")).validate()
|
||||
PolynomialMultiplicationConfig(variables="").validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
PolynomialMultiplicationConfig(
|
||||
|
|
@ -183,7 +183,7 @@ def test_multivariate_polynomial_equations_dataset_items():
|
|||
max_degree=2,
|
||||
min_polynomials=2,
|
||||
max_polynomials=5,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=3,
|
||||
|
|
@ -228,7 +228,7 @@ def test_polynomial_solutions_evaluation():
|
|||
max_degree=3,
|
||||
min_polynomials=2,
|
||||
max_polynomials=5,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=5,
|
||||
|
|
@ -257,18 +257,20 @@ def test_score_function():
|
|||
max_degree=3,
|
||||
min_polynomials=3,
|
||||
max_polynomials=3,
|
||||
variables=tuple(["x", "y", "xy"]),
|
||||
variables=tuple(["x", "y"]),
|
||||
allow_cross_variable_product=True,
|
||||
allow_multivariate_polynomials=True,
|
||||
size=1,
|
||||
size=3,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
for item in ds:
|
||||
poly_str = item["metadata"]["polynomial_expr"]
|
||||
poly_expr = sp.expand(poly_str)
|
||||
assert ds.score_answer(poly_str, item) == 0.05
|
||||
|
||||
poly_expr = str(sp.expand(poly_str))
|
||||
assert ds.score_answer(poly_expr, item) == 1.0
|
||||
|
||||
assert ds.score_answer(None, item) == 0.00
|
||||
assert ds.score_answer("Not a polynomial", item) == 0.01
|
||||
assert ds.score_answer("x**4", item) == 0.05
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue