diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 728c9c67..b659f6d5 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -9,6 +9,30 @@ from reasoning_gym.data import read_data_file from ..factory import ProceduralDataset, register_dataset +QUESTION_TEMPLATE = """Your task is to unsramble words in a sentence. + +For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words. + +The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.). + +Example: +- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt +- Output: meanderings Why any child can do that +- Explanation + - We unscramble each of the words independently. + - raendgmeins -> meanderings + - yWh -> Why + - nya -> any + - hilcd -> child + - anc -> can + - od -> do + - hatt -> that + - The final answer is: meanderings Why any child can do that + - Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added. + +Now, unscramble these words: {words} +""" + @dataclass class LetterJumbleConfig: @@ -89,7 +113,7 @@ class LetterJumbleDataset(ProceduralDataset): scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words] return { - "question": f"Unscramble these words: {' '.join(scrambled_words)}", + "question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)), "answer": " ".join(selected_words), "metadata": { "num_words": num_words, @@ -112,14 +136,16 @@ class LetterJumbleDataset(ProceduralDataset): float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - - s_answer = answer.strip().lower() - if not s_answer == entry["answer"].strip().lower(): - return 0.01 - else: - return 1.0 + oracle_answer = entry["answer"].strip() + if answer: + answer = answer.strip() + if answer == oracle_answer: + return 1.0 + elif answer.lower() == oracle_answer.lower(): + return 0.5 + else: + return 0.01 + return 0.0 register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)