fix: add score_answer() to number_sorting (#380)

* fix: add score_answer() to number_sorting

* chore: run pre-commit

* fix: use json.loads()

* fix: run isort()
This commit is contained in:
Jean Kaddour 2025-03-17 22:04:13 +00:00 committed by GitHub
parent 1c6f2d01ee
commit d6aad5a329
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 120 additions and 1 deletion

View file

@ -1,8 +1,9 @@
"""Number sorting task generator"""
import json
from dataclasses import dataclass
from random import Random
from typing import Optional
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -100,6 +101,68 @@ Please follow the instruction below:
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score the user's answer against the expected sorted list.

    The answer earns full credit only when it parses to a list of the same
    length as the expected answer, is ordered in the requested direction,
    and every element is within a small absolute tolerance of the expected
    element at the same position.

    Args:
        answer (Optional[str]): The user's answer string, written either as
            a JSON list or as a Python list literal.
        entry (dict[str, Any]): The original dataset entry. Reads
            ``entry["answer"]`` and ``entry["metadata"]["direction"]``.

    Returns:
        float: 1.0 for a correct answer, 0.0 for an incorrect, missing or
        unparseable one. Parsing and lookup failures are reported as 0.0
        instead of being raised.
    """
    # Imported locally so the block stays self-contained.
    import ast

    def parse_list(text):
        # JSON first; otherwise accept a Python *literal* only.
        # SECURITY: the answer is untrusted text. The previous eval() fallback
        # executed arbitrary expressions (e.g. "sorted([...])" or worse);
        # ast.literal_eval only evaluates literals and never runs code.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return ast.literal_eval(text)

    if answer is None:
        return 0.0

    # Absolute tolerance that absorbs small rounding differences.
    tolerance = 0.1

    try:
        user_answer = parse_list(answer)
        if not isinstance(user_answer, list):
            return 0.0

        expected_answer = parse_list(entry["answer"])
        if len(user_answer) != len(expected_answer):
            return 0.0

        # Elements may be numbers or numeric strings; compare them as floats.
        user_floats = [float(num) for num in user_answer]
        expected_floats = [float(num) for num in expected_answer]

        # Any direction other than "ascending" is treated as descending.
        descending = entry["metadata"]["direction"] != "ascending"
        if user_floats != sorted(user_floats, reverse=descending):
            return 0.0
    except (
        ValueError,
        TypeError,
        SyntaxError,
        KeyError,
        OverflowError,
        RecursionError,
        MemoryError,
    ):
        # Unparseable answer or malformed entry: score as incorrect.
        return 0.0

    # Written as "<= tolerance" rather than "> tolerance" so that a NaN
    # difference (NaN in the answer, or inf - inf) fails the check instead of
    # slipping through as a wildcard.
    within_tolerance = all(
        abs(got - want) <= tolerance
        for got, want in zip(user_floats, expected_floats)
    )
    return 1.0 if within_tolerance else 0.0
class NumberSortingCurriculum(BaseCurriculum):
def __init__(self):