diff --git a/README.md b/README.md index d8193e84..c329d21d 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `MazeDataset`: Generate a maze with a start and a goal - `CountdownDataset`: Generate number game tasks where numbers and operators must be combined to reach a target value - `NQueensDataset`: Generate N-Queens puzzles with configurable board size and number of starting queens +- `TsumegoDataset`: Generate Tsumego capture puzzles with variable board sizes and stone placements ## Future Generator Ideas diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 958dcd01..295f6cdf 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -14,6 +14,7 @@ from .n_queens import NQueensDataset from .sokoban import SokobanConfig, SokobanDataset from .sudoku import SudokuConfig, SudokuDataset from .tower_of_hanoi import HanoiConfig, HanoiDataset +from .tsumego import TsumegoConfig, TsumegoDataset __all__ = [ "CountdownConfig", @@ -31,4 +32,6 @@ __all__ = [ "HanoiConfig", "HanoiDataset", "NQueensDataset", + "TsumegoConfig", + "TsumegoDataset", ] diff --git a/reasoning_gym/games/tsumego.py b/reasoning_gym/games/tsumego.py new file mode 100644 index 00000000..c4473c4c --- /dev/null +++ b/reasoning_gym/games/tsumego.py @@ -0,0 +1,253 @@ +"""Go problem (tsumego) generator""" + +from dataclasses import dataclass +from random import Random +from typing import Dict, List, Optional, Set, Tuple + +from ..factory import ProceduralDataset, register_dataset + +# Added constant to avoid repetition of adjacent directions +DIRECTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)] + + +@dataclass +class TsumegoConfig: + """Configuration for Tsumego problem generation""" + + min_board_size: int = 9 + max_board_size: int = 13 + max_stones: int = 15 + size: int = 100 + seed: Optional[int] = None + + def __post_init__(self): + """Validate configuration parameters""" + if self.min_board_size < 5: + raise ValueError("min_board_size must be at least 5") + if self.max_board_size > 19: + raise ValueError("max_board_size must be at most 19") + if self.min_board_size > self.max_board_size: + raise ValueError("min_board_size must be less than or equal to max_board_size") + if self.max_stones < 5: + raise ValueError("max_stones must be at least 5") + + +class TsumegoDataset(ProceduralDataset): + """Generates Tsumego problems with configurable parameters""" + + def __init__(self, config: TsumegoConfig): + self._prompt_templates = [ + "Black to play and capture some stones.\nFind the key move.", + "It's Black's turn. Capture the marked white stones.", + "Play as Black. What's the best move to capture?", + ] + self._ko_point = None + super().__init__(config=config, seed=config.seed, size=config.size) + + # New helper method for board copying + def _copy_board(self, board: List[List[str]]) -> List[List[str]]: + """Return a deep copy of the board.""" + return [row[:] for row in board] + + def _get_liberties(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]: + """Get empty adjacent points (liberties) for a stone""" + size = len(board) + liberties = set() + for dr, dc in DIRECTIONS: + r, c = row + dr, col + dc + if 0 <= r < size and 0 <= c < size and board[r][c] == ".": + liberties.add((r, c)) + return liberties + + def _get_group(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]: + """Get all stones in the same group (connected stones of same color)""" + size = len(board) + color = board[row][col] + if color == ".": + return set() + + group = {(row, col)} + queue = [(row, col)] + while queue: + r, c = queue.pop(0) + for dr, dc in DIRECTIONS: + nr, nc = r + dr, c + dc + if 0 <= nr < size and 0 <= nc < size and board[nr][nc] == color and (nr, nc) not in group: + group.add((nr, nc)) + queue.append((nr, nc)) + return group + + def _count_liberties(self, board: List[List[str]], group: Set[Tuple[int, int]]) -> int: + """Count total liberties for a group of stones""" + liberties = set() + for row, col in group: + liberties.update(self._get_liberties(board, row, col)) + return len(liberties) + + def _would_capture(self, board: List[List[str]], row: int, col: int, color: str) -> bool: + """Check if a move would capture any opponent stones""" + size = len(board) + opponent = "O" if color == "X" else "X" + + # Make a copy of the board and place the stone + board_copy = self._copy_board(board) + board_copy[row][col] = color + + checked = set() + for dr, dc in DIRECTIONS: + r, c = row + dr, col + dc + if 0 <= r < size and 0 <= c < size and board_copy[r][c] == opponent and (r, c) not in checked: + group = self._get_group(board_copy, r, c) + checked.update(group) + if self._count_liberties(board_copy, group) == 0: + return True + return False + + def _is_valid_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool: + """Check if a move is legal (not suicide, unless it captures)""" + size = len(board) + if not (0 <= row < size and 0 <= col < size): + return False + if board[row][col] != ".": + return False + if (row, col) == self._ko_point: + return False + + # If the move captures opponent stones, it's valid + if self._would_capture(board, row, col, color): + return True + + board_copy = self._copy_board(board) + board_copy[row][col] = color + group = self._get_group(board_copy, row, col) + return self._count_liberties(board_copy, group) > 0 + + def _make_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool: + """Make a move and update ko point. Returns True if move was valid.""" + if not self._is_valid_move(board, row, col, color): + return False + + self._ko_point = None + board[row][col] = color + opponent = "O" if color == "X" else "X" + captured_stones = [] + + for dr, dc in DIRECTIONS: + r, c = row + dr, col + dc + if 0 <= r < len(board) and 0 <= c < len(board) and board[r][c] == opponent: + group = self._get_group(board, r, c) + if self._count_liberties(board, group) == 0: + captured_stones.extend(group) + + if len(captured_stones) == 1 and len(self._get_group(board, row, col)) == 1: + self._ko_point = captured_stones[0] + + for r, c in captured_stones: + board[r][c] = "." + + return True + + def _generate_capture_problem(self, size: int, rng: Random) -> Tuple[List[List[str]], str]: + """Generate a capture problem""" + board = [["." for _ in range(size)] for _ in range(size)] + stones_placed = 0 + max_stones = self.config.max_stones - 4 # Reserve space for capture setup + + while stones_placed < max_stones: + row = rng.randint(0, size - 1) + col = rng.randint(0, size - 1) + color = "X" if rng.random() < 0.5 else "O" + if board[row][col] == "." and self._is_valid_move(board, row, col, color): + self._make_move(board, row, col, color) + stones_placed += 1 + + tries = 0 + while tries < 50: + row = rng.randint(1, size - 2) + col = rng.randint(1, size - 2) + capture_neighbors = [(0, 0)] + DIRECTIONS # <-- incorporate (0,0) with the constant DIRECTIONS + if board[row][col] == "." and all(board[row + dr][col + dc] == "." for dr, dc in capture_neighbors): + board[row][col] = "O" + board[row - 1][col] = "O" + board[row + 1][col] = "O" + board[row][col - 1] = "O" + if self._is_valid_move(board, row, col + 1, "X"): + return board, f"{row+1},{col+2}" + tries += 1 + raise RuntimeError("Failed to generate a capture problem") + + def _board_to_string(self, board: List[List[str]]) -> str: + """Convert board to string representation""" + size = len(board) + # Column labels + cols = " " + " ".join(chr(ord("A") + i) for i in range(size)) + "\n" + # Board with row numbers + rows = [f"{size-i:2d} {' '.join(row)}" for i, row in enumerate(board)] + return cols + "\n".join(rows) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Tsumego problem + + Returns: + dict with: + - "question": Problem description and board state + - "answer": Solution move(s) + - "metadata": Problem details and configuration + """ + rng = Random(self.seed + idx if self.seed is not None else None) + size = rng.randint(self.config.min_board_size, self.config.max_board_size) + + board, solution = self._generate_capture_problem(size, rng) + board_str = self._board_to_string(board) + + return { + "question": ( + rng.choice(self._prompt_templates) + "\n\n" + board_str + "\n\n" + "Specify your move in coordinates (e.g. 'C4' for column C, row 4)" + ), + "answer": solution, + "metadata": { + "board_size": size, + "board": board, + "solution": solution, + }, + } + + def score_answer(self, answer: Optional[str], metadata: Dict[str, any]) -> float: + """Score the answer against the solution""" + if answer is None: + return 0.0 + answer = answer.strip() + if not answer: + return 0.01 + try: + # Parse expected solution in the format "row,col" + expected_row, expected_col = map(int, metadata["solution"].split(",")) + except Exception: + return 0.01 + try: + if "," in answer: + # Assume numeric format: "row,col" + row, col = map(int, answer.split(",")) + else: + # Assume letter-number format, e.g. "C4" + import re + + m = re.match(r"^([A-Za-z])(\d+)$", answer) + if not m: + return 0.01 + col_letter, row_str = m.group(1), m.group(2) + row = int(row_str) + col = ord(col_letter.upper()) - ord("A") + 1 + if (row, col) == (expected_row, expected_col): + return 1.0 + board_size = metadata["board_size"] + if 1 <= row <= board_size and 1 <= col <= board_size: + return 0.05 + except Exception: + return 0.01 + return 0.01 + + +# Register the dataset +register_dataset("tsumego", TsumegoDataset, TsumegoConfig) diff --git a/tests/test_tsumego.py b/tests/test_tsumego.py new file mode 100644 index 00000000..86ac203f --- /dev/null +++ b/tests/test_tsumego.py @@ -0,0 +1,233 @@ +"""Tests for Ttsumego problem generation""" + +import pytest +from random import Random + +from reasoning_gym.games.tsumego import TsumegoConfig, TsumegoDataset + + +def test_config_validation(): + # Valid configuration + TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=10, size=100, seed=42) + + # Invalid configurations + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=4, max_board_size=13, max_stones=10) # min_board_size too low + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=9, max_board_size=20, max_stones=10) # max_board_size too high + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=13, max_board_size=9, max_stones=10) # min_board_size > max_board_size + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=2) # max_stones too low + + +def test_dataset_item_properties(): + config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=15, size=100, seed=42) + dataset = TsumegoDataset(config) + item = dataset[0] + # Check that item has the required keys + for key in ["question", "answer", "metadata"]: + assert key in item + + metadata = item["metadata"] + for key in ["board_size", "board", "solution"]: + assert key in metadata + + board = metadata["board"] + # Board size should be equal to the fixed min_board_size for this test + assert len(board) == config.min_board_size + assert all(len(row) == config.min_board_size for row in board) + # Check stone count does not exceed max_stones + stone_count = sum(cell in "XO" for row in board for cell in row) + assert stone_count <= config.max_stones + + +def test_deterministic_generation(): + config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, seed=42) + dataset1 = TsumegoDataset(config) + dataset2 = TsumegoDataset(config) + for i in range(3): + item1 = dataset1[i] + item2 = dataset2[i] + assert item1["metadata"]["board"] == item2["metadata"]["board"] + assert item1["answer"] == item2["answer"] + + +def test_liberties_and_move(): + # Use a small board for simplicity + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=10) + dataset = TsumegoDataset(config) + + # Part 1: Liberty counting test + board_liberties = [ + [".", "O", ".", ".", "."], + ["O", "X", "O", ".", "."], + [".", "O", ".", ".", "."], + [".", ".", ".", ".", "."], + [".", ".", ".", ".", "."], + ] + liberties = dataset._get_liberties(board_liberties, 1, 1) + assert len(liberties) == 0 + liberties_edge = dataset._get_liberties(board_liberties, 0, 1) + assert len(liberties_edge) == 2 + + # Part 2: Test capturing move + # Construct a board where an enemy stone at (2,2) is surrounded on three sides, + # so that placing an "X" at (2,3) will remove its last liberty and capture it. + board_capture = [["." for _ in range(5)] for _ in range(5)] + board_capture[1][2] = "X" + board_capture[2][1] = "X" + board_capture[3][2] = "X" + board_capture[2][2] = "O" + # Now, (2,2) (enemy) has only one liberty at (2,3). + # Placing "X" at (2,3) should capture the enemy stone. + assert dataset._is_valid_move(board_capture, 2, 3, "X") + dataset._make_move(board_capture, 2, 3, "X") + # After move, captured_stones should be [(2,2)] and ko point set to (2,2). + assert not dataset._is_valid_move(board_capture, 2, 2, "O"), "Ko move should be invalid" + + # Part 3: Test suicide move (without capture) + board_move = [ + [".", "O", ".", ".", "."], + ["O", ".", "O", ".", "."], + [".", "O", ".", ".", "."], + [".", ".", ".", ".", "."], + [".", ".", ".", ".", "."], + ] + # Placing "X" at (1,1) would be suicide as all adjacent positions are occupied by "O". + assert not dataset._is_valid_move(board_move, 1, 1, "X") + + +def test_score_answer(): + config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10) + dataset = TsumegoDataset(config) + metadata = {"board_size": 9, "solution": "5,5"} + + # Correct numeric answer + assert dataset.score_answer("5,5", metadata) == 1.0 + + # Correct letter-number answer (E corresponds to 5) + assert dataset.score_answer("E5", metadata) == 1.0 + + # Valid but incorrect numeric move + assert dataset.score_answer("4,4", metadata) == 0.05 + + # Valid but incorrect letter-number move (D corresponds to 4) + assert dataset.score_answer("D4", metadata) == 0.05 + + # Invalid format + assert dataset.score_answer("invalid", metadata) == 0.01 + + # Empty answer + assert dataset.score_answer("", metadata) == 0.01 + + # None answer + assert dataset.score_answer(None, metadata) == 0.0 + + # Out-of-bound letter-number move: 'J' corresponds to 10 which is greater than board size = 9 + assert dataset.score_answer("J9", metadata) == 0.01 + + +# Additional tests for game logic edge cases + + +def test_get_group(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [ + ["X", "X", "."], + [".", "X", "O"], + [".", ".", "O"], + ] + group_X = dataset._get_group(board, 0, 0) + expected_group_X = {(0, 0), (0, 1), (1, 1)} + assert group_X == expected_group_X + + group_O = dataset._get_group(board, 1, 2) + expected_group_O = {(1, 2), (2, 2)} + assert group_O == expected_group_O + + +def test_count_liberties(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [ + ["X", "X", "."], + [".", "X", "O"], + [".", ".", "O"], + ] + group_X = {(0, 0), (0, 1), (1, 1)} + liberties_X = dataset._count_liberties(board, group_X) + # For (0,0): neighbor (1,0); (0,1): neighbor (0,2); (1,1): neighbors (1,0) and (2,1) + # Combined unique liberties: {(1,0), (0,2), (2,1)} so count should be 3 + assert liberties_X == 3 + + +def test_out_of_bounds_move(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [["." for _ in range(5)] for _ in range(5)] + # Test moves that are out of bounds + assert not dataset._is_valid_move(board, -1, 0, "X") + assert not dataset._is_valid_move(board, 0, -1, "X") + assert not dataset._is_valid_move(board, 5, 0, "X") + assert not dataset._is_valid_move(board, 0, 5, "X") + + +def test_move_on_occupied_intersection(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [["." for _ in range(5)] for _ in range(5)] + board[1][1] = "X" + # Attempting to play on an occupied spot should be invalid + assert not dataset._is_valid_move(board, 1, 1, "O") + assert not dataset._is_valid_move(board, 1, 1, "X") + + +def test_valid_non_capturing_move(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [["." for _ in range(5)] for _ in range(5)] + # A move on an empty board that doesn't result in capture or suicide should be valid + assert dataset._is_valid_move(board, 0, 0, "X") + move_result = dataset._make_move(board, 0, 0, "X") + assert move_result + assert board[0][0] == "X" + + +def test_multiple_capture(): + # Set up a board where a move will capture multiple opponent stones, + # which should not trigger the ko rule (ko point remains None) + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + board = [ + [".", ".", ".", ".", "."], + [".", "X", "X", "X", "."], + ["X", "O", "O", ".", "."], + [".", "X", "X", "X", "."], + [".", ".", ".", ".", "."], + ] + # Move at (2,3) with 'X' should capture the opponent stones at (2,1) and (2,2) + assert dataset._is_valid_move(board, 2, 3, "X") + move_result = dataset._make_move(board, 2, 3, "X") + assert move_result, "Move should be successfully made" + assert board[2][1] == ".", "Stone at (2,1) should be captured" + assert board[2][2] == ".", "Stone at (2,2) should be captured" + assert dataset._ko_point is None, "Ko point should not be set for multiple captures" + + +def test_would_capture(): + config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42) + dataset = TsumegoDataset(config) + # Create a scenario similar to the one in test_liberties_and_move for capturing + board_capture = [["." for _ in range(5)] for _ in range(5)] + board_capture[1][2] = "X" + board_capture[2][1] = "X" + board_capture[3][2] = "X" + board_capture[2][2] = "O" + # Placing 'X' at (2,3) should capture the stone at (2,2) + assert dataset._would_capture(board_capture, 2, 3, "X") + # In a scenario with no capture, the move should not be considered capturing + board_no_capture = [["." for _ in range(5)] for _ in range(5)] + board_no_capture[2][2] = "O" + assert not dataset._would_capture(board_no_capture, 0, 0, "X")