mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Merge pull request #92 from rishabhranawat/poly-reward
Add score_answer() for PolynomialEquationsDataset
This commit is contained in:
commit
7bd841d640
2 changed files with 124 additions and 1 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from pytest import approx
|
||||
from sympy import Symbol, sympify
|
||||
|
||||
from reasoning_gym import create_dataset
|
||||
|
|
@ -115,3 +116,25 @@ def test_polynomial_solutions_evaluation():
|
|||
f"Solution {solution} does not satisfy the polynomial {poly_str}. "
|
||||
f"Evaluated value: {evaluated_value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"oracle_answer, predicted_answer, expected_reward",
|
||||
[
|
||||
("4,-4.12", "4,-4.12", 1.0), # Exact match
|
||||
("4,-4.12", "4.0001,-4.120001", approx(0.9999, rel=1e-3)), # Very close match
|
||||
("4,-4.12", "4.1,-4.2", approx(0.9139, rel=1e-3)),
|
||||
("4,8", "4", approx(0.9, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies
|
||||
("4", "4,8", approx(0.95, rel=1e-3)), # extra solution -> extra solution penalty
|
||||
("-1,-2", "1,4", approx(0.06890, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4
|
||||
("", "1", approx(0, rel=1e-4)), # oracle no solution, predicted extra solution
|
||||
("1", "", approx(0, rel=1e-4)), # oracle has a solution, predicted no solution
|
||||
],
|
||||
)
|
||||
def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expected_reward):
|
||||
# You might want to parameterize cfg as well
|
||||
cfg = PolynomialEquationsConfig(seed=999, size=3)
|
||||
ds = PolynomialEquationsDataset(cfg)
|
||||
|
||||
actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer})
|
||||
assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue