Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent 061282e373
commit 5d7fbac0ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
106 changed files with 403 additions and 507 deletions

View file

@ -160,17 +160,15 @@ class BitwiseArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is correct; otherwise, 0.01 unless no answer is provided, in which case 0.
"""
if answer is None:
return 0.0
if isinstance(answer, str):
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
pass
try:
solved = verify_solution(entry["metadata"]["problem"], answer)
if solved:
return 1.0
except Exception:
return 0.01
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -428,7 +428,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
# we suppose the answer is the last occurrence of the expected answer type
if answer is None:
if not isinstance(answer, str) or len(answer) == 0:
return 0.0
oracle_answer = entry["answer"]
@ -439,9 +439,6 @@ class CalendarArithmeticDataset(ProceduralDataset):
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
CalendarTask.WEEKDAY_OF_DATE.value,
}:
if not answer:
return 0.0
answer = answer.strip()
oracle_answer = oracle_answer
weekdays = {d.name.title() for d in Weekday}

View file

@ -178,7 +178,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
+ problem_str
)
return {"question": problem_str, "answer": answer, "metadata": {}}
return {"question": problem_str, "answer": str(answer), "metadata": {}}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""
@ -189,12 +189,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
Returns:
float: 1.0 if the user's answer is within tolerance; otherwise, 0.01.
"""
if answer is None:
if not isinstance(answer, str):
return 0.0
try:
user_ans: Decimal = Decimal(answer)
correct_ans: Decimal = entry["answer"]
correct_ans: Decimal = Decimal(entry["answer"])
# Determine tolerance based on the desired precision.
precision: int = self.config.max_num_decimal_places
@ -202,9 +202,9 @@ class DecimalArithmeticDataset(ProceduralDataset):
if abs(user_ans - correct_ans) <= tol:
return 1.0
except Exception:
return 0.01
pass
return 0.01
return 0.0
# Register the dataset with the factory.

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from decimal import Decimal
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
result -= c
expression = " ".join(expression_parts)
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
try:
q = Decimal(f"0.{'0' * max(decimal_places)}")
result = result.quantize(q)
except InvalidOperation:
pass
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
Returns:
1.0 for exact numerical match, 0.01 otherwise
"""
if answer is None or len(answer.strip()) == 0:
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
try:
student_answer = Decimal(answer.strip())
oracle_answer = Decimal(entry["answer"])
return 1.0 if student_answer == oracle_answer else 0.01
except (ValueError, TypeError, ArithmeticError):
return 0.01
if student_answer == oracle_answer:
return 1.0
except Exception:
pass
return 0.0
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)

View file

@ -138,12 +138,11 @@ class DiceDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
register_dataset("dice", DiceDataset, DiceConfig)

View file

@ -65,14 +65,13 @@ class NumberFormatDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["metadata"]["solution"]
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
try:
answer = float(answer.strip().replace(",", ""))
if abs(answer - oracle_answer) < 1e-2:
return 1.0
return 0.01
except:
return 0.0
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -44,10 +44,8 @@ class PowerFunctionDataset(ProceduralDataset):
return 1.0
elif difference < 1e-1:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -246,7 +246,7 @@ class TimeIntervalsDataset(ProceduralDataset):
Returns a score between 0 and 1, with partial credit for answers that are
close to correct in the appropriate units/format
"""
if not answer:
if not isinstance(answer, str):
return 0.0
expected = entry["answer"]