mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Minor question template & score_answer improvements (#261)
* math prompt improvements * ignore brackets in complex_arithmetic results * improve additional instruction in prompt of polynomial_equations * more strict tests for score_answer in polynomial_equations * simplify special reward handling * fix test_intermediate_integration * fix sokoban dataset * add common dataset score_answer consistency test
This commit is contained in:
parent
bf24999bb0
commit
b2904ccab9
106 changed files with 403 additions and 507 deletions
|
|
@ -358,16 +358,14 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
answer_formatted = answer.strip().lower()
|
||||
solved = answer_formatted == entry["answer"].strip().lower()
|
||||
if solved:
|
||||
oracle_answer = entry["answer"].strip().lower()
|
||||
if answer_formatted == oracle_answer:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
pass
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -169,21 +169,15 @@ Buttons:
|
|||
|
||||
The function awards 1.0 for a correct answer and less otherwise.
|
||||
"""
|
||||
if answer == None:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
# Get correct solution from metadata
|
||||
correct_solution = entry["metadata"].get("solution_path", [])
|
||||
|
||||
# Normalize both answers
|
||||
def normalize_seq(seq):
|
||||
"""Handle both string and list inputs by converting to string first"""
|
||||
# Convert sequence to string representation if it's a list
|
||||
input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
|
||||
return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
|
||||
def normalize_seq(seq: str) -> list[str]:
|
||||
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
|
||||
|
||||
user_sequence = normalize_seq(answer)
|
||||
target_sequence = normalize_seq("".join(correct_solution))
|
||||
target_sequence = normalize_seq(entry["answer"])
|
||||
|
||||
# Exact sequence match required
|
||||
if user_sequence == target_sequence:
|
||||
|
|
@ -196,7 +190,7 @@ Buttons:
|
|||
return 1.0 # Different answer, but qually correct
|
||||
return 0.5 # Alternative scoring - you're correct, but not optimal
|
||||
|
||||
return 0.1
|
||||
return 0.0
|
||||
|
||||
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
|
||||
"""Simulate button presses to verify solutions"""
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"].strip()
|
||||
if answer is not None and len(answer) > 0:
|
||||
if isinstance(answer, str) and len(answer) > 0:
|
||||
oracle_answer = entry["answer"].strip()
|
||||
answer = answer.strip()
|
||||
|
||||
# Exact answer
|
||||
|
|
@ -145,8 +145,6 @@ class ShortestPathDataset(ProceduralDataset):
|
|||
elif self._is_valid_path(matrix, answer):
|
||||
return 0.5
|
||||
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue