mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix: add score_answer() to number_sorting (#380)
* fix: add score_answer() to number_sorting * chore: run pre-commit * fix: use json.loads() * fix: run isort()
This commit is contained in:
parent
1c6f2d01ee
commit
d6aad5a329
2 changed files with 120 additions and 1 deletions
|
|
@ -1,8 +1,9 @@
|
|||
"""Number sorting task generator"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -100,6 +101,68 @@ Please follow the instruction below:
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score the user's answer against the expected sorted list.

    The answer must parse to a list of the same length as the expected
    answer, be ordered according to ``entry["metadata"]["direction"]``, and
    match the expected values element-wise within a small tolerance.

    Args:
        answer: The user's answer string, or None if no answer was given.
        entry: The dataset entry; ``entry["answer"]`` holds the expected list
            as a string and ``entry["metadata"]["direction"]`` is
            ``"ascending"`` (any other value is treated as descending).

    Returns:
        1.0 for a correct answer, 0.0 otherwise (including any answer that
        cannot be parsed or converted to numbers).
    """
    # Imported locally so this method stays self-contained.
    import ast

    if answer is None:
        return 0.0

    def _parse_list_literal(text):
        """Parse *text* as JSON, falling back to a Python literal."""
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # SECURITY: the answer is untrusted model output, so it must never
            # reach eval(). ast.literal_eval only accepts literals (lists,
            # strings, numbers, ...) and cannot execute code; it still handles
            # Python-style lists such as "['1.0', '2.0']" that JSON rejects.
            return ast.literal_eval(text)

    tolerance = 0.1  # Allows for small rounding differences between answers

    try:
        user_answer = _parse_list_literal(answer)
        if not isinstance(user_answer, list):
            return 0.0

        expected_answer = _parse_list_literal(entry["answer"])
        if len(user_answer) != len(expected_answer):
            return 0.0

        # Convert both answers to floats for comparison
        user_floats = [float(num) for num in user_answer]
        expected_floats = [float(num) for num in expected_answer]

        # The user's answer must itself be ordered in the requested direction
        descending = entry["metadata"]["direction"] != "ascending"
        if user_floats != sorted(user_floats, reverse=descending):
            return 0.0

        for got, want in zip(user_floats, expected_floats):
            # Written as "not <=" rather than ">" on purpose: every comparison
            # with NaN is False, so a "> tolerance" test would let NaN values
            # slip through as if they matched.
            if not abs(got - want) <= tolerance:
                return 0.0

        return 1.0
    except Exception:
        # Deliberate best-effort: any parsing or conversion failure means the
        # answer is simply scored as incorrect.
        return 0.0
|
||||
|
||||
|
||||
class NumberSortingCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue