diff --git a/reasoning_gym/games/tsumego.py b/reasoning_gym/games/tsumego.py index c56bbd25..4e3048d3 100644 --- a/reasoning_gym/games/tsumego.py +++ b/reasoning_gym/games/tsumego.py @@ -200,7 +200,7 @@ class TsumegoDataset(ProceduralDataset): board, solution = self._generate_capture_problem(size, rng) board_str = self._board_to_string(board) - solution_str = f"{chr(ord('A')+solution[1])}{solution[0]+1}" + solution_str = f"{chr(ord('A')+solution[1])}{size-solution[0]}" return { "question": ( @@ -225,22 +225,20 @@ class TsumegoDataset(ProceduralDataset): if not answer: return 0.01 metadata = entry["metadata"] - try: - # get solution from (row, col) tuple - expected_row, expected_col = metadata["solution"] - except Exception: - return 0.01 + board_size = len(metadata["board"]) + expected_row, expected_col = metadata["solution"] # get solution from (row, col) tuple + try: # Assume letter-number format, e.g. "C4" m = re.match(r"^([A-Za-z])(\d+)$", answer) if not m: return 0.01 col_letter, row_str = m.group(1), m.group(2) - row = int(row_str) - 1 + row = board_size - int(row_str) col = ord(col_letter.upper()) - ord("A") if (row, col) == (expected_row, expected_col): return 1.0 - board_size = metadata["board_size"] + if 0 <= row < board_size and 0 <= col < board_size: return 0.05 except Exception: diff --git a/tests/test_tsumego.py b/tests/test_tsumego.py index a1e6e6b5..82a5b67f 100644 --- a/tests/test_tsumego.py +++ b/tests/test_tsumego.py @@ -100,7 +100,10 @@ def test_liberties_and_move(): def test_score_answer(): config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, size=5) dataset = TsumegoDataset(config) - entry = {"metadata": {"board_size": 9, "solution": (4, 4)}} + + # prepare dummy + entry = dataset[0].copy() + entry["metadata"]["solution"] = (4, 4) # Correct letter-number answer (E corresponds to 5) assert dataset.score_answer("E5", entry) == 1.0 @@ -120,7 +123,9 @@ def test_score_answer(): # Out-of-bound letter-number move: 'J' corresponds to 10 which is greater than board size = 9 assert dataset.score_answer("J9", entry) == 0.01 + # test optimal score for answers for x in dataset: + assert len(x["metadata"]["board"]) == x["metadata"]["difficulty"]["board_size"] assert dataset.score_answer(x["answer"], entry=x) == 1.0