mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
added score_answer implementation and tests
This commit is contained in:
parent
f5838da534
commit
b0d21cf664
4 changed files with 148 additions and 26 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -221,16 +221,43 @@ class IntermediateIntegrationDataset(ProceduralDataset):
|
|||
integrand = self._generate_repeated_parts(rng, x)
|
||||
|
||||
answer = sympy.integrate(integrand, x)
|
||||
answer_str = str(answer) + " + C"
|
||||
|
||||
return {
|
||||
"question": rng.choice(self.prompt_template).format(integrand=integrand),
|
||||
"answer": str(answer) + " + C",
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"integrand": str(integrand),
|
||||
"problem_type": problem_type,
|
||||
"variable": str(x),
|
||||
"type": substitution_type if problem_type == "substitution" else parts_type,
|
||||
"expected_answer_expression": answer,
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the problem"""
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
try:
|
||||
var = metadata["variable"]
|
||||
x = sympy.Symbol(var)
|
||||
# Parse answer while allowing integration constant 'C'
|
||||
user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")})
|
||||
# Compute derivative of student's answer
|
||||
derivative = sympy.diff(user_expr, x)
|
||||
integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x})
|
||||
|
||||
# Check mathematical equivalence through simplification
|
||||
if sympy.simplify(derivative - integrand) == 0:
|
||||
reward = 1.0
|
||||
elif answer.strip():
|
||||
reward = 0.05
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
return reward
|
||||
|
||||
|
||||
register_dataset("intermediate_integration", IntermediateIntegrationDataset, IntermediateIntegrationConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue