mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
Minor question template & score_answer improvements (#261)
* math prompt improvements * ignore brackets in complex_arithmetic results * improve additional instruction in prompt of polynomial_equations * more strict tests for score_answer in polynomial_equations * simplify special reward handling * fix test_intermediate_integration * fix sokoban dataset * add common dataset score_answer consistency test
This commit is contained in:
parent
bf24999bb0
commit
b2904ccab9
106 changed files with 403 additions and 507 deletions
|
|
@ -130,12 +130,9 @@ Return the final state of the program.
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer != entry["answer"]:
|
||||
return 0.01
|
||||
else:
|
||||
if answer == entry["answer"]:
|
||||
return 1.0 # Yay
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -108,9 +108,9 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
# check if answer is python list of lists
|
||||
answer = self._matrix_to_str(eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
correct_mapping = {}
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
ans_arr = json.loads(answer)
|
||||
correct_arr = json.loads(entry["answer"])
|
||||
except Exception:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
total_cells = 0
|
||||
correct_cells = 0
|
||||
|
|
|
|||
|
|
@ -228,12 +228,13 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
|
|||
try:
|
||||
danswer = json.loads(answer)
|
||||
solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer)
|
||||
if not solved:
|
||||
return 0.01 # json was parsable but solution incorrect
|
||||
else:
|
||||
if solved:
|
||||
return 1.0 # Yay
|
||||
else:
|
||||
return 0.01 # json parsable
|
||||
except Exception:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
if answer_str == oracle_str:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01
|
||||
reward = 0.01 # json parsable
|
||||
except Exception:
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
|
|
|||
|
|
@ -303,11 +303,11 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
|
|||
danswer = json.loads(answer)
|
||||
valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer)
|
||||
if not valid:
|
||||
return 0.01
|
||||
return 0.01 # json parsable
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("jugs", JugsDataset, JugsConfig)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
|
||||
# Each word in the expected answer is worth an equal fraction of 1.0
|
||||
total_words = len(expected_words)
|
||||
score_per_word = 1.0 / total_words if total_words else 0
|
||||
score_per_word = 1.0 / total_words if total_words > 0 else 0
|
||||
|
||||
# Calculate scores word by word
|
||||
scores = []
|
||||
|
|
@ -142,18 +142,16 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"].strip().lower()
|
||||
if answer:
|
||||
answer = answer.strip().lower()
|
||||
if answer == oracle_answer:
|
||||
return 1.0 # Perfect score!
|
||||
else:
|
||||
partial_score = self.partial(oracle_answer, answer)
|
||||
return partial_score
|
||||
return 0.01
|
||||
answer = answer.strip().lower()
|
||||
if answer == oracle_answer:
|
||||
return 1.0 # Perfect score!
|
||||
else:
|
||||
partial_score = self.partial(oracle_answer, answer)
|
||||
return partial_score
|
||||
|
||||
|
||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
||||
|
|
|
|||
|
|
@ -144,8 +144,6 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
|
||||
if oracle_answer in answer:
|
||||
return len(oracle_answer) / len(answer)
|
||||
else:
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -92,14 +92,14 @@ class PalindromeDataset(ProceduralDataset):
|
|||
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
|
||||
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05
|
||||
- An answer that is a string, but not a palindrome gives 0.02
|
||||
- An empty string gives 0.01.
|
||||
- An empty string gives 0.0
|
||||
- None gives 0.0.
|
||||
"""
|
||||
if answer is None or not isinstance(answer, str):
|
||||
return 0.0 # No answer given
|
||||
|
||||
if answer == "":
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
metadata = entry["metadata"]
|
||||
answer = answer.strip().lower()
|
||||
|
|
|
|||
|
|
@ -95,9 +95,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
|
||||
if answer == oracle:
|
||||
return 1.0
|
||||
return 0.01
|
||||
except Exception:
|
||||
return 0.0
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer based on the metadata"""
|
||||
|
||||
if not answer:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
reward = 0.0
|
||||
|
|
@ -91,8 +91,6 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
reward = 1.0
|
||||
elif oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
pass
|
||||
return reward
|
||||
|
|
|
|||
|
|
@ -108,14 +108,12 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if isinstance(answer, str):
|
||||
s_answer = answer.strip()
|
||||
if s_answer == str(entry["answer"]):
|
||||
return 1.0
|
||||
|
||||
s_answer = answer.strip()
|
||||
if not s_answer == str(entry["answer"]):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class SentenceReorderingDataset(ProceduralDataset):
|
|||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,14 +52,14 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
expected_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
if expected_answer.lower() == answer.lower():
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.05
|
||||
except:
|
||||
reward = 0.01
|
||||
reward = 0.0
|
||||
return reward
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -126,11 +126,9 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
try:
|
||||
answer = " ".join(str(item) for item in eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
if isinstance(answer, str):
|
||||
if answer == oracle_answer:
|
||||
return 1.0
|
||||
else:
|
||||
|
|
@ -83,9 +83,9 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
# check if answer is python list of characters
|
||||
answer = "".join(eval(answer))
|
||||
if answer == oracle_answer:
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.1
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -221,8 +221,8 @@ class WordLadderDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if answer is None:
|
||||
return 0
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
answer_words = tuple(s.strip() for s in answer.upper().split(","))
|
||||
|
||||
|
|
@ -239,17 +239,17 @@ class WordLadderDataset(ProceduralDataset):
|
|||
# 4. all words are in our vocabulary
|
||||
|
||||
if len(answer_words) < 2:
|
||||
return 0
|
||||
return 0.0
|
||||
|
||||
if answer_words[0] != start_word or answer_words[-1] != end_word:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
if not all(len(w) == word_length for w in answer_words):
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
for i in range(1, len(answer_words)):
|
||||
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
reward = 1.0
|
||||
for word in answer_words:
|
||||
|
|
|
|||
|
|
@ -121,8 +121,6 @@ class WordSortingDataset(ProceduralDataset):
|
|||
return 1.0
|
||||
elif sorted(parsed_answer) == oracle_answer:
|
||||
return 0.2
|
||||
else:
|
||||
return 0.01
|
||||
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue