Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent 061282e373
commit 5d7fbac0ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
106 changed files with 403 additions and 507 deletions

View file

@ -130,12 +130,9 @@ Return the final state of the program.
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer != entry["answer"]:
return 0.01
else:
if answer == entry["answer"]:
return 1.0 # Yay
return 0.0
# Register the dataset

View file

@ -108,9 +108,9 @@ class BinaryMatrixDataset(ProceduralDataset):
# check if answer is python list of lists
answer = self._matrix_to_str(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
return 0.0
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -200,7 +200,7 @@ class CryptarithmDataset(ProceduralDataset):
Returns:
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
correct_mapping = {}

View file

@ -106,7 +106,7 @@ class GameOfLifeDataset(ProceduralDataset):
ans_arr = json.loads(answer)
correct_arr = json.loads(entry["answer"])
except Exception:
return 0.01
return 0.0
total_cells = 0
correct_cells = 0

View file

@ -228,12 +228,13 @@ Return your solution as a JSON map of vertices to colors. (For example: {{"0": 1
try:
danswer = json.loads(answer)
solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer)
if not solved:
return 0.01 # json was parsable but solution incorrect
else:
if solved:
return 1.0 # Yay
else:
return 0.01 # json parsable
except Exception:
return 0.0
pass
return 0.0
register_dataset("graph_color", GraphColorDataset, GraphColorConfig)

View file

@ -95,7 +95,7 @@ class GroupAnagramsDataset(ProceduralDataset):
if answer_str == oracle_str:
reward = 1.0
else:
reward = 0.01
reward = 0.01 # json parsable
except Exception:
reward = 0.0
return reward

View file

@ -303,11 +303,11 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil
danswer = json.loads(answer)
valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer)
if not valid:
return 0.01
return 0.01 # json parsable
else:
return 1.0 # Yay
except Exception as e:
return 0.01
return 0.0
register_dataset("jugs", JugsDataset, JugsConfig)

View file

@ -116,7 +116,7 @@ class LetterJumbleDataset(ProceduralDataset):
# Each word in the expected answer is worth an equal fraction of 1.0
total_words = len(expected_words)
score_per_word = 1.0 / total_words if total_words else 0
score_per_word = 1.0 / total_words if total_words > 0 else 0
# Calculate scores word by word
scores = []
@ -142,18 +142,16 @@ class LetterJumbleDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if not answer:
if not isinstance(answer, str):
return 0.0
oracle_answer = entry["answer"].strip().lower()
if answer:
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
return 0.01
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0 # Perfect score!
else:
partial_score = self.partial(oracle_answer, answer)
return partial_score
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)

View file

@ -144,8 +144,6 @@ class ManipulateMatrixDataset(ProceduralDataset):
if oracle_answer in answer:
return len(oracle_answer) / len(answer)
else:
return 0.01
return 0.0

View file

@ -92,14 +92,14 @@ class PalindromeDataset(ProceduralDataset):
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05
- An answer that is a string, but not a palindrome gives 0.02
- An empty string gives 0.01.
- An empty string gives 0.0
- None gives 0.0.
"""
if answer is None or not isinstance(answer, str):
return 0.0 # No answer given
if answer == "":
return 0.01
return 0.0
metadata = entry["metadata"]
answer = answer.strip().lower()

View file

@ -95,9 +95,8 @@ class PalindromePartitioningDataset(ProceduralDataset):
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
if answer == oracle:
return 1.0
return 0.01
except Exception:
return 0.0
pass
return 0.0
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:

View file

@ -80,7 +80,7 @@ class PoolMatrixDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer based on the metadata"""
if not answer:
if not isinstance(answer, str):
return 0.0
reward = 0.0
@ -91,8 +91,6 @@ class PoolMatrixDataset(ProceduralDataset):
reward = 1.0
elif oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except Exception:
pass
return reward

View file

@ -108,14 +108,12 @@ class RansomNoteDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if isinstance(answer, str):
s_answer = answer.strip()
if s_answer == str(entry["answer"]):
return 1.0
s_answer = answer.strip()
if not s_answer == str(entry["answer"]):
return 0.01
else:
return 1.0
return 0.0
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)

View file

@ -110,7 +110,7 @@ class SentenceReorderingDataset(ProceduralDataset):
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -52,14 +52,14 @@ class SpellBackwardDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
try:
if expected_answer.lower() == answer.lower():
reward = 1.0
else:
reward = 0.05
except:
reward = 0.01
reward = 0.0
return reward

View file

@ -126,11 +126,9 @@ class SpiralMatrixDataset(ProceduralDataset):
try:
answer = " ".join(str(item) for item in eval(answer))
if answer == oracle_answer:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0

View file

@ -75,7 +75,7 @@ class StringInsertionDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:
if isinstance(answer, str):
if answer == oracle_answer:
return 1.0
else:
@ -83,9 +83,9 @@ class StringInsertionDataset(ProceduralDataset):
# check if answer is python list of characters
answer = "".join(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.1
except Exception:
pass
return 0.0
def __getitem__(self, idx: int) -> dict:

View file

@ -221,8 +221,8 @@ class WordLadderDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None:
return 0
if not isinstance(answer, str):
return 0.0
answer_words = tuple(s.strip() for s in answer.upper().split(","))
@ -239,17 +239,17 @@ class WordLadderDataset(ProceduralDataset):
# 4. all words are in our vocabulary
if len(answer_words) < 2:
return 0
return 0.0
if answer_words[0] != start_word or answer_words[-1] != end_word:
return 0.01
return 0.0
if not all(len(w) == word_length for w in answer_words):
return 0.01
return 0.0
for i in range(1, len(answer_words)):
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
return 0.01
return 0.0
reward = 1.0
for word in answer_words:

View file

@ -121,8 +121,6 @@ class WordSortingDataset(ProceduralDataset):
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
else:
return 0.01
return 0.0