Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent bf24999bb0
commit b2904ccab9
106 changed files with 403 additions and 507 deletions

View file

@ -358,16 +358,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
if answer is not None:
if isinstance(answer, str):
try:
answer_formatted = answer.strip().lower()
solved = answer_formatted == entry["answer"].strip().lower()
if solved:
oracle_answer = entry["answer"].strip().lower()
if answer_formatted == oracle_answer:
reward = 1.0
else:
reward = 0.01
except:
reward = 0.01
pass
return reward

View file

@ -169,21 +169,15 @@ Buttons:
The function awards 1.0 for a correct answer and less otherwise.
"""
if answer == None:
if not isinstance(answer, str):
return 0.0
# Get correct solution from metadata
correct_solution = entry["metadata"].get("solution_path", [])
# Normalize both answers
def normalize_seq(seq):
"""Handle both string and list inputs by converting to string first"""
# Convert sequence to string representation if it's a list
input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
def normalize_seq(seq: str) -> list[str]:
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
user_sequence = normalize_seq(answer)
target_sequence = normalize_seq("".join(correct_solution))
target_sequence = normalize_seq(entry["answer"])
# Exact sequence match required
if user_sequence == target_sequence:
@ -196,7 +190,7 @@ Buttons:
return 1.0 # Different answer, but qually correct
return 0.5 # Alternative scoring - you're correct, but not optimal
return 0.1
return 0.0
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
"""Simulate button presses to verify solutions"""

View file

@ -125,8 +125,8 @@ class ShortestPathDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
if answer is not None and len(answer) > 0:
if isinstance(answer, str) and len(answer) > 0:
oracle_answer = entry["answer"].strip()
answer = answer.strip()
# Exact answer
@ -145,8 +145,6 @@ class ShortestPathDataset(ProceduralDataset):
elif self._is_valid_path(matrix, answer):
return 0.5
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict: