def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a Futoshiki answer against the oracle solution of ``entry``.

    Scoring rules:

    * ``None`` or an empty answer scores ``0.0``.
    * An answer equal to ``entry["answer"]`` (ignoring trailing whitespace
      on each line) scores ``1.0``.
    * Otherwise the answer is scanned line by line; every line containing
      exactly ``board_size`` decimal digits is treated as the next board
      row and compared cell by cell with ``entry["metadata"]["solution"]``.
      The fraction of matching cells is multiplied by ``0.9`` as a penalty
      for not using the standard format.
    * If the answer is longer than the oracle answer, the reward is scaled
      by ``len(oracle_answer) / len(answer)``.

    Args:
        answer: The candidate answer text, or ``None``.
        entry: Dataset item providing ``"answer"`` (oracle string) and
            ``"metadata"["solution"]`` (list of rows of ints).

    Returns:
        A reward in the range ``[0.0, 1.0]``.
    """
    if not answer:
        return 0.0

    oracle_answer = entry["answer"]
    solution: list[list[int]] = entry["metadata"]["solution"]
    board_size: int = len(solution[0])

    # 1. match answer without trailing whitespaces
    answer_stripped = "\n".join(line.rstrip() for line in answer.split("\n"))
    oracle_answer_stripped = "\n".join(line.rstrip() for line in oracle_answer.split("\n"))

    if answer_stripped == oracle_answer_stripped:
        reward = 1.0
    else:
        # 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
        row = 0
        num_matching = 0
        for line in answer.split("\n"):
            if row >= len(solution):
                # Every solution row has been consumed; further digit lines
                # (e.g. a repeated grid) must not index past the board.
                break
            # isdecimal() (unlike isnumeric()) only accepts characters that
            # int() can convert, so symbols such as superscripts or
            # fractions are ignored instead of raising ValueError.
            numbers = [int(c) for c in line if c.isdecimal()]
            if len(numbers) != board_size:
                continue  # ignore lines that do not look like a board row
            num_matching += sum(1 for a, b in zip(solution[row], numbers) if a == b)
            row += 1

        reward = num_matching / (board_size * board_size)
        reward *= 0.9  # penalty for not using standard format

    if len(answer) > len(oracle_answer):
        reward *= len(oracle_answer) / len(answer)  # penalty for additional length
    return reward
surrounding text." + assert 0 < dataset.score_answer(anwser_with_additional_text, item) < 0.9 + + partially_correct = anwser_with_additional_text.replace("1", "2") + assert dataset.score_answer(partially_correct, item) > 0.1 + + bad_answer = "\n".join(anwser_with_additional_text.split("\n")[::-1]) + assert dataset.score_answer(bad_answer, item) < 0.1