mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Minor question template & score_answer improvements (#261)
* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
parent
061282e373
commit
5d7fbac0ad
106 changed files with 403 additions and 507 deletions
|
|
@ -160,17 +160,15 @@ class BitwiseArithmeticDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: 1.0 if the user's answer is correct; otherwise, 0.01 unless no answer is provided, in which case 0.
|
||||
"""
|
||||
if answer is None:
|
||||
return 0.0
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
solved = verify_solution(entry["metadata"]["problem"], answer)
|
||||
if solved:
|
||||
return 1.0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
solved = verify_solution(entry["metadata"]["problem"], answer)
|
||||
if solved:
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.01
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
|
|
|
|||
|
|
@ -428,7 +428,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
# we suppose the answer is the last occurence of the expected answer type
|
||||
if answer is None:
|
||||
if not isinstance(answer, str) or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
|
|
@ -439,9 +439,6 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
|
||||
CalendarTask.WEEKDAY_OF_DATE.value,
|
||||
}:
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
answer = answer.strip()
|
||||
oracle_answer = oracle_answer
|
||||
weekdays = {d.name.title() for d in Weekday}
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
+ problem_str
|
||||
)
|
||||
|
||||
return {"question": problem_str, "answer": answer, "metadata": {}}
|
||||
return {"question": problem_str, "answer": str(answer), "metadata": {}}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""
|
||||
|
|
@ -189,12 +189,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: 1.0 if the user's answer is within tolerance; otherwise, 0.01.
|
||||
"""
|
||||
if answer is None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
user_ans: Decimal = Decimal(answer)
|
||||
correct_ans: Decimal = entry["answer"]
|
||||
correct_ans: Decimal = Decimal(entry["answer"])
|
||||
|
||||
# Determine tolerance based on the desired precision.
|
||||
precision: int = self.config.max_num_decimal_places
|
||||
|
|
@ -202,9 +202,9 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
if abs(user_ans - correct_ans) <= tol:
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.01
|
||||
pass
|
||||
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
result -= c
|
||||
|
||||
expression = " ".join(expression_parts)
|
||||
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
|
||||
try:
|
||||
q = Decimal(f"0.{'0' * max(decimal_places)}")
|
||||
result = result.quantize(q)
|
||||
except InvalidOperation:
|
||||
pass
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
Returns:
|
||||
1.0 for exact numerical match, 0.01 otherwise
|
||||
"""
|
||||
if answer is None or len(answer.strip()) == 0:
|
||||
if not isinstance(answer, str) or len(answer.strip()) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
student_answer = Decimal(answer.strip())
|
||||
oracle_answer = Decimal(entry["answer"])
|
||||
|
||||
return 1.0 if student_answer == oracle_answer else 0.01
|
||||
except (ValueError, TypeError, ArithmeticError):
|
||||
return 0.01
|
||||
if student_answer == oracle_answer:
|
||||
return 1.0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)
|
||||
|
|
|
|||
|
|
@ -138,12 +138,11 @@ class DiceDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
if isinstance(answer, str):
|
||||
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
|
||||
return 1.0 # Yay
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("dice", DiceDataset, DiceConfig)
|
||||
|
|
|
|||
|
|
@ -65,14 +65,13 @@ class NumberFormatDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["metadata"]["solution"]
|
||||
if answer is not None and len(answer) > 0:
|
||||
if isinstance(answer, str) and len(answer) > 0:
|
||||
try:
|
||||
answer = float(answer.strip().replace(",", ""))
|
||||
if abs(answer - oracle_answer) < 1e-2:
|
||||
return 1.0
|
||||
return 0.01
|
||||
except:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -44,10 +44,8 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
return 1.0
|
||||
elif difference < 1e-1:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class TimeIntervalsDataset(ProceduralDataset):
|
|||
Returns a score between 0 and 1, with partial credit for answers that are
|
||||
close to correct in the appropriate units/format
|
||||
"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
expected = entry["answer"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue