Merge pull request #92 from rishabhranawat/poly-reward

Add score_answer() for PolynomialEquationsDataset
This commit is contained in:
Andreas Köpf 2025-02-09 19:30:24 +01:00 committed by GitHub
commit 7bd841d640
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 124 additions and 1 deletions

View file

@ -1,4 +1,5 @@
import pytest
from pytest import approx
from sympy import Symbol, sympify
from reasoning_gym import create_dataset
@ -115,3 +116,25 @@ def test_polynomial_solutions_evaluation():
f"Solution {solution} does not satisfy the polynomial {poly_str}. "
f"Evaluated value: {evaluated_value}"
)
@pytest.mark.parametrize(
"oracle_answer, predicted_answer, expected_reward",
[
("4,-4.12", "4,-4.12", 1.0), # Exact match
("4,-4.12", "4.0001,-4.120001", approx(0.9999, rel=1e-3)), # Very close match
("4,-4.12", "4.1,-4.2", approx(0.9139, rel=1e-3)),
("4,8", "4", approx(0.9, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
("4", "4,8", approx(0.95, rel=1e-3)), # extra solution -> extra solution penalty
("-1,-2", "1,4", approx(0.06890, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
("", "1", approx(0, rel=1e-4)), # oracle no solution, predicted extra solution
("1", "", approx(0, rel=1e-4)), # oracle has a solution, predicted no solution
],
)
def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expected_reward):
# You might want to parameterize cfg as well
cfg = PolynomialEquationsConfig(seed=999, size=3)
ds = PolynomialEquationsDataset(cfg)
actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer})
assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats