add ArcAgiDataset class, fix score_entry() metadata params

This commit is contained in:
Andreas Koepf 2025-02-08 23:18:18 +01:00
parent 2ad0965fdc
commit 4e49806d22
20 changed files with 194 additions and 93 deletions

View file

@ -137,10 +137,10 @@ def test_score_function():
seed=42,
)
assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1
assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
assert ds.score_answer(None, ds[0]) == 0.00
assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
assert ds.score_answer("x**4", ds[0]) == 0.05
def test_multivariate_score_function():
@ -160,7 +160,7 @@ def test_multivariate_score_function():
seed=42,
)
assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1
assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
assert ds.score_answer(None, ds[0]) == 0.00
assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
assert ds.score_answer("x**4", ds[0]) == 0.05