reasoning-gym/reasoning_gym/games/sudoku.py
Oliver Stanley 7475a20700
include ranges rather than sampled values in difficulty metadata dicts (#387)
* update difficulty metadata for logic datasets

* update difficulty metadata for graph datasets

* update difficulty metadata for geometry datasets

* update difficulty metadata for games datasets

* update difficulty metadata for cognition datasets

* update difficulty metadata for arithmetic datasets

* update difficulty metadata for arc datasets

* update difficulty metadata for algorithmic datasets

* update difficulty metadata for algebra datasets

* use tuples

* update tests

* update tests
2025-03-20 10:27:03 +01:00

279 lines
9.7 KiB
Python

"""Sudoku puzzle generator"""
import copy
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@dataclass
class SudokuConfig:
    """
    Settings that control how sudoku puzzles are generated.

    The number of blanked-out cells is drawn from [min_empty, max_empty].
    Generation gets noticeably slower once puzzles have many (~60+) empty cells.
    """

    min_empty: int = 30  # Fewest cells to blank out
    max_empty: int = 50  # Most cells to blank out
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Assert that both empty-cell bounds lie in [0, 64] and that min_empty <= max_empty."""
        # 64 blanks leaves 17 clues: the fewest a 9x9 Sudoku can have while still having a unique solution.
        upper_limit = 64
        assert 0 <= self.min_empty <= upper_limit, "min_empty must be between 0 and 64"
        assert self.min_empty <= self.max_empty <= upper_limit, "max_empty must be between min_empty and 64"
class SudokuDataset(ProceduralDataset):
    """Generates sudoku puzzles with configurable difficulty"""

    def __init__(self, config: SudokuConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        self._current_idx = 0
        return self

    def __next__(self):
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _is_valid(self, board: list[list[int]], row: int, col: int, num: int) -> bool:
        """Return True if ``num`` does not already occur in the row, column or 3x3 box of (row, col)."""
        # Check row
        if num in board[row]:
            return False
        # Check column
        if num in [board[i][col] for i in range(9)]:
            return False
        # Check 3x3 box
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        for i in range(box_row, box_row + 3):
            for j in range(box_col, box_col + 3):
                if board[i][j] == num:
                    return False
        return True

    def _get_possible_values(self, board: list[list[int]], row: int, col: int) -> set[int]:
        """Get all values 1-9 not yet used in the row, column or 3x3 box of (row, col)."""
        row_values = set(board[row])
        col_values = set(board[i][col] for i in range(9))

        # Get filled values in the current 3x3 subgrid
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        box_values = set()
        for i in range(box_row, box_row + 3):
            for j in range(box_col, box_col + 3):
                box_values.add(board[i][j])

        # 0 (empty) may be among the used values; it is simply never a candidate.
        used_values = row_values | col_values | box_values
        return set(range(1, 10)) - used_values

    def _solve(self, board: list[list[int]]) -> bool:
        """Solve ``board`` in place with backtracking; return True once every cell is filled."""
        empty = self._find_empty(board)
        if not empty:
            return True

        row, col = empty
        # Candidates are tried in the set's iteration order; generated boards depend on it.
        for num in self._get_possible_values(board, row, col):
            board[row][col] = num
            if self._solve(board):
                return True
            board[row][col] = 0

        return False

    def _find_empty(self, board: list[list[int]]) -> Optional[tuple[int, int]]:
        """Return (row, col) of the first empty (0) cell in row-major order, or None if the board is full."""
        for i in range(9):
            for j in range(9):
                if board[i][j] == 0:
                    return (i, j)
        return None

    def _generate_solved_board(self, rng: Random) -> list[list[int]]:
        """Generate a complete solved sudoku board using ``rng`` for randomness."""
        board = [[0] * 9 for _ in range(9)]

        # Fill the three diagonal boxes first: they share no row, column or box, so any
        # permutation in each is consistent.
        for i in range(0, 9, 3):
            nums = list(range(1, 10))
            rng.shuffle(nums)
            pos = 0
            for r in range(i, i + 3):
                for c in range(i, i + 3):
                    board[r][c] = nums[pos]
                    pos += 1

        # Solve the rest
        self._solve(board)
        return board

    def _count_solutions(self, board: list[list[int]], limit: int = 2) -> int:
        """
        Count the solutions of ``board``, giving up once ``limit`` have been found.

        Returns the exact count when it is below ``limit``, otherwise a value >= ``limit``.
        Note: ``board`` is modified in place. Cells are filled during the search and are not
        reset when the search stops early at ``limit``, so callers should pass a copy.
        """

        def _get_min_possibilities_cell(board: list[list[int]]) -> Optional[tuple[int, int, set[int]]]:
            """
            Get the empty cell with the fewest candidate values, as (row, col, candidates).
            Returns None if the board has no empty cells.
            """
            min_possibilities = 10
            min_cell = None
            min_values = None
            for i in range(9):
                for j in range(9):
                    if board[i][j] == 0:
                        possible = self._get_possible_values(board, i, j)
                        if len(possible) < min_possibilities:
                            min_possibilities = len(possible)
                            min_cell = (i, j)
                            min_values = possible
                            # 0 candidates (dead end) or 1 (forced move) cannot be beaten: stop scanning.
                            if min_possibilities <= 1:
                                return (*min_cell, min_values)
            return (*min_cell, min_values) if min_cell else None

        def _count_solutions_helper(board: list[list[int]]) -> int:
            cell_info = _get_min_possibilities_cell(board)
            if not cell_info:
                return 1  # no empty cells left: one complete solution

            row, col, possible_values = cell_info
            count = 0
            for num in possible_values:
                board[row][col] = num
                count += _count_solutions_helper(board)
                if count >= limit:
                    return count  # early exit; board is left partially filled
                board[row][col] = 0

            return count

        return _count_solutions_helper(board)

    def _create_puzzle(self, solved_board: list[list[int]], num_empty: int, rng: Random) -> list[list[int]]:
        """
        Create a puzzle by blanking up to ``num_empty`` cells of ``solved_board``.

        Cells are visited in a random order, and a cell is only blanked if the puzzle keeps a
        unique solution. Fewer than ``num_empty`` cells may therefore be removed when uniqueness
        cannot be preserved. ``solved_board`` itself is not modified.
        """
        puzzle = [row[:] for row in solved_board]
        cells = [(i, j) for i in range(9) for j in range(9)]
        rng.shuffle(cells)

        num_removed = 0
        for i, j in cells:
            # Check the target before removing anything, so num_empty <= 0 leaves the board full.
            if num_removed >= num_empty:
                break
            saved = puzzle[i][j]
            puzzle[i][j] = 0
            # _count_solutions fills cells in place, so search on a copy.
            puzzle_copy = copy.deepcopy(puzzle)
            # Check if removing this clue breaks uniqueness
            if self._count_solutions(puzzle_copy) > 1:
                puzzle[i][j] = saved
            else:
                num_removed += 1

        return puzzle

    def _board_to_string(self, board: list[list[int]]) -> str:
        """Render the board as 9 space-separated rows, showing empty (0) cells as '_'."""
        return "\n".join(" ".join(str(x) if x != 0 else "_" for x in row) for row in board)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single sudoku puzzle.

        The RNG is seeded with ``self.seed + idx``, so a given index always yields the same item.
        Returns a dict with "question" (prompt text), "answer" (solution grid as text) and
        "metadata" (puzzle/solution grids, number of empty cells, difficulty range).
        """
        rng = Random(self.seed + idx)

        # Generate solved board
        solved_board = self._generate_solved_board(rng)

        # Create puzzle by removing numbers
        num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
        puzzle = self._create_puzzle(solved_board, num_empty, rng)
        # _create_puzzle can fall short of num_empty when further removals would make the
        # solution ambiguous, so report how many cells are actually blank.
        actual_empty = sum(row.count(0) for row in puzzle)

        # Format as strings
        puzzle_str = self._board_to_string(puzzle)
        solution_str = self._board_to_string(solved_board)

        question = (
            f"Solve this Sudoku puzzle:\n{puzzle_str}\n"
            "Respond with only your answer, formatted as the puzzle, a 9x9 grid with numbers separated by spaces, and rows separated by newlines."
        )

        return {
            "question": question,
            "answer": solution_str,
            "metadata": {
                "puzzle": puzzle,
                "solution": solved_board,
                "num_empty": actual_empty,
                "difficulty": {
                    "empty": (self.config.min_empty, self.config.max_empty),
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """
        Score ``answer`` against the solution stored in ``entry``.

        - 0.0 if ``answer`` is not a non-empty string.
        - 1.0 if it equals the reference grid, ignoring trailing whitespace on each line.
        - Otherwise partial credit. Lines containing exactly one digit (1-9) per column are
          compared in order with the solution rows. The reward is the fraction of matching
          cells, times 0.9 for not using the standard format.
        - In both scored cases, answers longer than the reference are scaled by
          len(reference) / len(answer).
        """
        if not isinstance(answer, str) or len(answer) == 0:
            return 0.0

        oracle_answer = entry["answer"]
        metadata = entry["metadata"]
        solution: list[list[int]] = metadata["solution"]
        board_size: int = len(solution[0])

        # 1. match answer without trailing whitespaces
        answer_stripped = "\n".join(line.rstrip() for line in answer.split("\n"))
        oracle_answer_stripped = "\n".join(line.rstrip() for line in oracle_answer.split("\n"))

        if answer_stripped == oracle_answer_stripped:
            reward = 1.0
        else:
            # 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
            row = 0
            num_matching = 0
            for ln in answer.split("\n"):
                if row >= len(solution):
                    break
                numbers = [int(c) for c in ln if c in "123456789"]
                if len(numbers) != board_size:
                    continue  # ignore lines that do not hold a full row of digits
                for a, b in zip(solution[row], numbers):
                    if a == b:
                        num_matching += 1
                row += 1

            reward = num_matching / (board_size * board_size)
            reward *= 0.9  # penalty for not using standard format

        if len(answer) > len(oracle_answer):
            reward *= len(oracle_answer) / len(answer)  # penalty for additional length
        return reward
class SudokuCurriculum(BaseCurriculum):
    """Curriculum for the sudoku dataset, with difficulty driven by the number of empty cells."""

    def __init__(self):
        super().__init__(SudokuCurriculum.__name__, SudokuConfig)

        # A single range attribute, tied to SudokuConfig.min_empty / max_empty through
        # the lower/upper field names below.
        empty_cells = RangeAttributeDefinition(
            name="empty",
            description="Number of empty cells in the puzzle",
            levels=[20, 30, 40, 50],
            lower_field_name="min_empty",
            upper_field_name="max_empty",
            ensure_interval=True,
        )
        self._define_attributes(empty_cells)
# Register the dataset class, its config and its curriculum under the key "sudoku"
# (presumably so the factory can look them up by name; see ..factory).
register_dataset("sudoku", SudokuDataset, SudokuConfig, SudokuCurriculum)