add ArcAgiDataset class, fix score_entry() metadata params

This commit is contained in:
Andreas Koepf 2025-02-08 23:18:18 +01:00
parent 2ad0965fdc
commit 4e49806d22
20 changed files with 194 additions and 93 deletions

View file

@ -127,11 +127,12 @@ class ComplexArithmeticDataset(ProceduralDataset):
return student_result
def score_answer(self, answer: str, metadata: dict) -> float:
def score_answer(self, answer: Optional[str], entry: dict) -> float:
"""Score the answer using exponential distance-based scoring."""
if answer is None:
return 0.0
metadata = entry["metadata"]
try:
student_result = self.parse_string_to_complex(answer)
expected_result = complex(*metadata["result"])

View file

@ -235,9 +235,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Determine if the solution provided solves the problem"""
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
var = metadata["variable"]

View file

@ -138,8 +138,9 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
return polynomial_expr
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
predicted_poly = sp.parse_expr(answer)

View file

@ -80,9 +80,10 @@ class SimpleIntegrationDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Determine if the solution provided solves the problem"""
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
var = metadata["variable"]