added score_answer implementation and tests

This commit is contained in:
joesharratt1229 2025-02-02 17:18:56 +00:00
parent f5838da534
commit b0d21cf664
4 changed files with 148 additions and 26 deletions

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional
import sympy
@ -221,16 +221,43 @@ class IntermediateIntegrationDataset(ProceduralDataset):
integrand = self._generate_repeated_parts(rng, x)
answer = sympy.integrate(integrand, x)
answer_str = str(answer) + " + C"
return {
"question": rng.choice(self.prompt_template).format(integrand=integrand),
"answer": str(answer) + " + C",
"answer": answer_str,
"metadata": {
"integrand": str(integrand),
"problem_type": problem_type,
"variable": str(x),
"type": substitution_type if problem_type == "substitution" else parts_type,
"expected_answer_expression": answer,
},
}
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
"""Determine if the solution provided solves the problem"""
reward = 0.0
if answer is not None:
try:
var = metadata["variable"]
x = sympy.Symbol(var)
# Parse answer while allowing integration constant 'C'
user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")})
# Compute derivative of student's answer
derivative = sympy.diff(user_expr, x)
integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x})
# Check mathematical equivalence through simplification
if sympy.simplify(derivative - integrand) == 0:
reward = 1.0
elif answer.strip():
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
return reward
register_dataset("intermediate_integration", IntermediateIntegrationDataset, IntermediateIntegrationConfig)