fix prompt and scoring function

This commit is contained in:
Zafir Stojanovski 2025-02-17 13:17:29 +01:00
parent 2cbaab2918
commit 51c31e7015

View file

@ -9,6 +9,30 @@ from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset
# Prompt shown for each letter-jumble item. The `{words}` placeholder is filled
# via `QUESTION_TEMPLATE.format(words=...)` with the space-joined scrambled words.
# The worked example is included so the expected output format (same word order,
# same capitalization, nothing added) is unambiguous.
QUESTION_TEMPLATE = """Your task is to unscramble words in a sentence.
For each word in a sentence, the letters may have been randomly shuffled. Your task is to unscramble the words.
The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.).
Example:
- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt
- Output: meanderings Why any child can do that
- Explanation
- We unscramble each of the words independently.
- raendgmeins -> meanderings
- yWh -> Why
- nya -> any
- hilcd -> child
- anc -> can
- od -> do
- hatt -> that
- The final answer is: meanderings Why any child can do that
- Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added.
Now, unscramble these words: {words}
"""
@dataclass
class LetterJumbleConfig:
@ -89,7 +113,7 @@ class LetterJumbleDataset(ProceduralDataset):
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
return {
"question": f"Unscramble these words: {' '.join(scrambled_words)}",
"question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
"answer": " ".join(selected_words),
"metadata": {
"num_words": num_words,
@ -112,14 +136,16 @@ class LetterJumbleDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
s_answer = answer.strip().lower()
if not s_answer == entry["answer"].strip().lower():
return 0.01
else:
return 1.0
oracle_answer = entry["answer"].strip()
if answer:
answer = answer.strip()
if answer == oracle_answer:
return 1.0
elif answer.lower() == oracle_answer.lower():
return 0.5
else:
return 0.01
return 0.0
# Module-level registration call: passes the name "letter_jumble" together with
# the dataset class and its config class to `register_dataset` (imported from
# ..factory). NOTE(review): presumably this makes the dataset constructible by
# that name — confirm against the factory implementation.
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)