reasoning-gym/reasoning_gym/games/mini_sudoku.py
Oliver Stanley 7475a20700
include ranges rather than sampled values in difficulty metadata dicts (#387)
* update difficulty metadata for logic datasets

* update difficulty metadata for graph datasets

* update difficulty metadata for geometry datasets

* update difficulty metadata for games datasets

* update difficulty metadata for cognition datasets

* update difficulty metadata for arithmetic datasets

* update difficulty metadata for arc datasets

* update difficulty metadata for algorithmic datasets

* update difficulty metadata for algebra datasets

* use tuples

* update tests

* update tests
2025-03-20 10:27:03 +01:00

260 lines
9.1 KiB
Python

"""Mini Sudoku (4x4) puzzle generator"""
import copy
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@dataclass
class MiniSudokuConfig:
    """Parameters for generating 4x4 mini sudoku puzzles.

    Each puzzle blanks a number of cells chosen between ``min_empty`` and
    ``max_empty``.
    """

    # Lower bound on blanked cells. The generator may blank fewer than this
    # when removing another clue would make the puzzle's solution ambiguous.
    min_empty: int = 8
    # Upper bound on blanked cells.
    max_empty: int = 12
    # Seed forwarded to the dataset.
    seed: Optional[int] = None
    # Number of items in the virtual dataset.
    size: int = 500

    def validate(self):
        """Assert that 0 <= min_empty <= max_empty <= 12 (raises AssertionError otherwise)."""
        # More than 12 empty cells is incompatible with a unique solution.
        cap = 12
        low, high = self.min_empty, self.max_empty
        assert low >= 0 and low <= cap, "min_empty must be between 0 and 12"
        assert high >= low and high <= cap, "max_empty must be between min_empty and 12"
class MiniSudokuDataset(ProceduralDataset):
    """Generates 4x4 sudoku puzzles with configurable difficulty"""

    def __init__(self, config: MiniSudokuConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        self._current_idx = 0
        return self

    def __next__(self):
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _is_valid(self, board: list[list[int]], row: int, col: int, num: int) -> bool:
        """Return True if ``num`` can be placed at ``(row, col)``.

        A placement is valid when ``num`` does not already appear in the same row,
        the same column, or the same 2x2 box.
        """
        # Check row
        if num in board[row]:
            return False
        # Check column
        if any(board[i][col] == num for i in range(4)):
            return False
        # Check 2x2 box
        box_row, box_col = 2 * (row // 2), 2 * (col // 2)
        for i in range(box_row, box_row + 2):
            for j in range(box_col, box_col + 2):
                if board[i][j] == num:
                    return False
        return True

    def _solve(self, board: list[list[int]]) -> bool:
        """Fill the empty (0) cells of ``board`` in place by backtracking.

        Candidates are tried in the order 1-4. Returns True once the board is
        complete. On failure every cell this call filled is reset to 0 and False
        is returned.
        """
        empty = self._find_empty(board)
        if empty is None:
            return True
        row, col = empty
        for num in range(1, 5):
            if self._is_valid(board, row, col, num):
                board[row][col] = num
                if self._solve(board):
                    return True
                board[row][col] = 0
        return False

    def _find_empty(self, board: list[list[int]]) -> Optional[tuple[int, int]]:
        """Return the first ``(row, col)`` holding 0 in row-major order, or None if the board is full."""
        for i in range(4):
            for j in range(4):
                if board[i][j] == 0:
                    return (i, j)
        return None

    def _generate_solved_board(self, rng: Random) -> list[list[int]]:
        """Generate a complete, valid 4x4 board using ``rng`` for randomness.

        The top-left and bottom-right 2x2 boxes are seeded with random permutations
        of 1-4, and the remaining cells are completed by backtracking. Some seedings
        have no completion, so this retries up to 100 times.

        Raises:
            RuntimeError: if no attempt yields a complete board.
        """
        board = [[0] * 4 for _ in range(4)]

        # Try multiple times to generate a valid board
        max_attempts = 100
        for _ in range(max_attempts):
            # Start fresh
            for r in range(4):
                for c in range(4):
                    board[r][c] = 0

            # Fill diagonal boxes first. They share no row, column or box, so any
            # permutation in each is mutually consistent.
            for i in range(0, 4, 2):
                nums = list(range(1, 5))
                rng.shuffle(nums)
                pos = 0
                for r in range(i, i + 2):
                    for c in range(i, i + 2):
                        board[r][c] = nums[pos]
                        pos += 1

            # Try to solve the rest
            if self._solve(board):
                return board

        raise RuntimeError("Failed to generate valid mini sudoku board")

    def _count_solutions(self, board: list[list[int]], limit: int = 2) -> int:
        """Count the completions of ``board``, stopping early once ``limit`` is reached.

        A return value below ``limit`` is the exact number of solutions. A value of
        ``limit`` or more means "at least that many". Every cell filled during the
        search is reset to 0 before returning, so ``board`` is left exactly as it
        was passed in.
        """

        def _count_solutions_helper(board: list[list[int]]) -> int:
            empty = self._find_empty(board)
            if empty is None:
                return 1
            row, col = empty
            count = 0
            for num in range(1, 5):
                if not self._is_valid(board, row, col, num):
                    continue
                board[row][col] = num
                count += _count_solutions_helper(board)
                # Undo the trial placement before any early exit so the caller's
                # board is never left partially filled.
                board[row][col] = 0
                if count >= limit:
                    break
            return count

        return _count_solutions_helper(board)

    def _create_puzzle(self, solved_board: list[list[int]], num_empty: int, rng: Random) -> list[list[int]]:
        """Blank up to ``num_empty`` cells of a copy of ``solved_board``.

        Cells are visited in an order shuffled by ``rng``. A cell whose removal would
        allow more than one solution is kept, so fewer than ``num_empty`` cells may
        end up blank. ``solved_board`` itself is not modified.
        """
        puzzle = [row[:] for row in solved_board]
        cells = [(i, j) for i in range(4) for j in range(4)]
        rng.shuffle(cells)

        num_removed = 0
        for i, j in cells:
            # Check the target before removing anything, so num_empty == 0 blanks
            # no cells instead of as many as uniqueness allows.
            if num_removed >= num_empty:
                break
            saved = puzzle[i][j]
            puzzle[i][j] = 0

            # Check if removing this clue breaks uniqueness. _count_solutions
            # restores the board, so no defensive copy is needed.
            if self._count_solutions(puzzle) > 1:
                puzzle[i][j] = saved
            else:
                num_removed += 1

        return puzzle

    def _board_to_string(self, board: list[list[int]]) -> str:
        """Render ``board`` as space-separated rows joined by newlines, with "_" for empty cells."""
        return "\n".join(" ".join(str(x) if x != 0 else "_" for x in row) for row in board)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single mini sudoku puzzle.

        The item is built deterministically from ``Random(self.seed + idx)``.
        """
        rng = Random(self.seed + idx)

        # Generate solved board
        solved_board = self._generate_solved_board(rng)

        # Create puzzle by removing numbers
        num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
        puzzle = self._create_puzzle(solved_board, num_empty, rng)

        # Update the num_empty to be used in the metadata if we couldn't remove as many as we wanted
        num_empty = sum(1 for row in puzzle for x in row if x == 0)

        # Format as strings
        puzzle_str = self._board_to_string(puzzle)
        solution_str = self._board_to_string(solved_board)

        question = (
            "In 4x4 Mini Sudoku:\n"
            "- Each row must contain each number from 1-4 exactly once\n"
            "- Each column must contain each number 1-4 exactly once\n"
            "- Each 2x2 subgrid must contain each number 1-4 exactly once\n"
            f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}\n"
            "Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows.\n"
        )

        return {
            "question": question,
            "answer": solution_str,
            "metadata": {
                "puzzle": puzzle,
                "solution": solved_board,
                "num_empty": num_empty,
                "difficulty": {
                    "empty": (self.config.min_empty, self.config.max_empty),
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score ``answer`` against the entry's reference solution, returning a value in [0, 1].

        - Exact match: an answer equal to ``entry["answer"]`` after stripping trailing
          whitespace from each line scores 1.0.
        - Partial credit: otherwise, each line containing exactly ``board_size`` digits
          is read as the next solution row, and the score is the fraction of matching
          cells times 0.9 (a format penalty).
        - Length penalty: in both cases, an answer longer than the reference string is
          scaled by ``len(reference) / len(answer)``.

        Non-string or empty answers score 0.0.
        """
        if not isinstance(answer, str) or len(answer) == 0:
            return 0.0

        oracle_answer = entry["answer"]
        metadata = entry["metadata"]
        solution: list[list[int]] = metadata["solution"]
        board_size: int = len(solution[0])

        # 1. match answer without trailing whitespaces
        answer_stripped = "\n".join(line.rstrip() for line in answer.split("\n"))
        oracle_answer_stripped = "\n".join(line.rstrip() for line in oracle_answer.split("\n"))

        if answer_stripped == oracle_answer_stripped:
            reward = 1.0
        else:
            # 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
            row = 0
            num_matching = 0
            for ln in answer.split("\n"):
                if row >= len(solution):
                    break
                numbers = [int(c) for c in ln if c in "123456789"]
                if len(numbers) != board_size:
                    continue  # skip lines that do not hold exactly one full row of digits
                num_matching += sum(a == b for a, b in zip(solution[row], numbers))
                row += 1

            reward = num_matching / (board_size * board_size)
            reward *= 0.9  # penalty for not using standard format

        if len(answer) > len(oracle_answer):
            reward *= len(oracle_answer) / len(answer)  # penalty for additional length
        return reward
class MiniSudokuCurriculum(BaseCurriculum):
    """Curriculum for mini sudoku, driven by the number of empty cells.

    A single range attribute, ``empty``, controls the config's
    ``min_empty``/``max_empty`` bounds.
    """

    def __init__(self):
        super().__init__(MiniSudokuCurriculum.__name__, MiniSudokuConfig)

        # Level values for the empty-cell range; ensure_interval=True is passed
        # through to RangeAttributeDefinition unchanged.
        empty_cells = RangeAttributeDefinition(
            name="empty",
            description="Number of empty cells in the puzzle",
            levels=[4, 6, 8, 10],
            lower_field_name="min_empty",
            upper_field_name="max_empty",
            ensure_interval=True,
        )
        self._define_attributes(empty_cells)
# Register the dataset class together with its config class and curriculum under the
# name "mini_sudoku" (register_dataset comes from ..factory; presumably this name is the
# lookup key there — confirm against the factory).
register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig, MiniSudokuCurriculum)