mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
added score_answer implementation and tests
This commit is contained in:
parent
f5838da534
commit
b0d21cf664
4 changed files with 148 additions and 26 deletions
|
|
@ -1,6 +1,3 @@
|
|||
import random
|
||||
from fractions import Fraction
|
||||
|
||||
import pytest
|
||||
import sympy
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
|
@ -63,7 +60,7 @@ def test_simple_integration_dataset_items():
|
|||
|
||||
assert "integrand" in item["metadata"]
|
||||
assert "variable" in item["metadata"]
|
||||
assert "antiderivative" in item["metadata"]
|
||||
assert "expected_answer_expression" in item["metadata"]
|
||||
|
||||
# Verify answer is a mathematical expression
|
||||
answer = item["answer"]
|
||||
|
|
@ -71,15 +68,50 @@ def test_simple_integration_dataset_items():
|
|||
assert isinstance(parse_expr(answer), sympy.Expr)
|
||||
|
||||
|
||||
def test_simple_integration_solution_verification():
|
||||
"""Test for solution verification of each answer"""
|
||||
config = SimpleIntegrationConfig(seed=42, size=10)
|
||||
def test_verify_answer():
|
||||
config = SimpleIntegrationConfig(seed=42)
|
||||
dataset = SimpleIntegrationDataset(config)
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
score = dataset.score_answer(item["answer"], item["metadata"])
|
||||
assert score == 1.0
|
||||
|
||||
for item in dataset:
|
||||
integrand = parse_expr(item["metadata"]["integrand"])
|
||||
variable = sympy.Symbol(item["metadata"]["variable"])
|
||||
answer = parse_expr(item["answer"].replace(" + C", ""))
|
||||
|
||||
# Verify that the derivative of the answer equals the integrand
|
||||
assert sympy.simplify(sympy.diff(answer, variable) - integrand) == 0
|
||||
def test_score_answer_cases():
|
||||
"""Test various answer scoring scenarios"""
|
||||
config = SimpleIntegrationConfig(seed=42)
|
||||
dataset = SimpleIntegrationDataset(config)
|
||||
x = sympy.Symbol("x")
|
||||
X = sympy.Symbol("X")
|
||||
|
||||
# Test cases: (answer, metadata, expected_score)
|
||||
test_cases = [
|
||||
# Correct answers
|
||||
("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0),
|
||||
("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0),
|
||||
# Correct without explicit constant
|
||||
("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
|
||||
# Incorrect but properly formatted
|
||||
("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
|
||||
# Malformed expressions
|
||||
("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
|
||||
# Empty answer
|
||||
("", {"variable": "x", "integrand": "2*x"}, 0.01),
|
||||
# Case sensitivity
|
||||
("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
|
||||
("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
|
||||
# Alternative constant notation
|
||||
("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
|
||||
# Simplification required
|
||||
("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0),
|
||||
("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0),
|
||||
]
|
||||
|
||||
for answer, metadata, expected in test_cases:
|
||||
score = dataset.score_answer(answer, metadata)
|
||||
assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue