mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
fix prompt and scoring function
This commit is contained in:
parent
2cbaab2918
commit
51c31e7015
1 changed files with 35 additions and 9 deletions
|
|
@ -9,6 +9,30 @@ from reasoning_gym.data import read_data_file
|
|||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Your task is to unsramble words in a sentence.
|
||||
|
||||
For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words.
|
||||
|
||||
The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.).
|
||||
|
||||
Example:
|
||||
- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt
|
||||
- Output: meanderings Why any child can do that
|
||||
- Explanation
|
||||
- We unscramble each of the words independently.
|
||||
- raendgmeins -> meanderings
|
||||
- yWh -> Why
|
||||
- nya -> any
|
||||
- hilcd -> child
|
||||
- anc -> can
|
||||
- od -> do
|
||||
- hatt -> that
|
||||
- The final answer is: meanderings Why any child can do that
|
||||
- Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added.
|
||||
|
||||
Now, unscramble these words: {words}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LetterJumbleConfig:
|
||||
|
|
@ -89,7 +113,7 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
|
||||
|
||||
return {
|
||||
"question": f"Unscramble these words: {' '.join(scrambled_words)}",
|
||||
"question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
|
||||
"answer": " ".join(selected_words),
|
||||
"metadata": {
|
||||
"num_words": num_words,
|
||||
|
|
@ -112,14 +136,16 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
|
||||
s_answer = answer.strip().lower()
|
||||
if not s_answer == entry["answer"].strip().lower():
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0
|
||||
oracle_answer = entry["answer"].strip()
|
||||
if answer:
|
||||
answer = answer.strip()
|
||||
if answer == oracle_answer:
|
||||
return 1.0
|
||||
elif answer.lower() == oracle_answer.lower():
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue