diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 86cce2b3..3aab43f8 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -123,6 +123,25 @@ class LetterJumbleDataset(ProceduralDataset): }, } + def partial(self, expected_answer, model_answer): + expected_words = expected_answer.split() + model_words = model_answer.split() + + # Each word in the expected answer is worth an equal fraction of 1.0 + total_words = len(expected_words) + score_per_word = 1.0 / total_words if total_words else 0 + + # Calculate scores word by word + scores = [] + for i, word in enumerate(expected_words): + # Check if the corresponding word exists in model_answer and matches exactly + if i < len(model_words) and word == model_words[i]: + scores.append(score_per_word) + else: + scores.append(0.0) + + return min(1.0, sum(scores)) + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Determine if the solution provided solves this task. @@ -136,16 +155,18 @@ class LetterJumbleDataset(ProceduralDataset): float: The computed score between 0.0 and 1.0. """ - oracle_answer = entry["answer"].strip() + if not answer: + return 0.0 + + oracle_answer = entry["answer"].strip().lower() if answer: - answer = answer.strip() + answer = answer.strip().lower() if answer == oracle_answer: - return 1.0 - elif answer.lower() == oracle_answer.lower(): - return 0.5 + return 1.0 # Perfect score! else: - return 0.01 - return 0.0 + partial_score = self.partial(oracle_answer, answer) + return partial_score + return 0.01 register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig) diff --git a/tests/test_letter_jumble.py b/tests/test_letter_jumble.py index 89f860b5..0a11ce1e 100644 --- a/tests/test_letter_jumble.py +++ b/tests/test_letter_jumble.py @@ -110,8 +110,12 @@ def test_letter_jumble_dataset_items(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="gibberish", entry=item) == 0.01 assert dataset.score_answer(answer=None, entry=item) == 0.0 + answera = item["answer"].split(" ") + answera[0] = "flippityfloop" + answera[1] = "doopadoopadoop" + answerf = " ".join(answera) + assert 0.01 <= dataset.score_answer(answer=answerf, entry=item) <= 1.0 def test_letter_jumble_iteration():