mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
Greatly speed up solver
This commit is contained in:
parent
145ceeb109
commit
dce5d9367d
1 changed files with 307 additions and 86 deletions
|
|
@ -2,10 +2,9 @@
|
|||
|
||||
import copy
|
||||
import itertools
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -59,7 +58,7 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
# Add random adjacency constraints consistent with generated solved grid
|
||||
constraints = self._generate_random_constraints(solution, self.config.difficulty, rng)
|
||||
# Starting with full solution, remove clues to desired difficulty
|
||||
puzzle = self._remove_clues(copy.deepcopy(solution), constraints, self.config.difficulty, rng)
|
||||
puzzle = self._remove_clues(copy.deepcopy(solution), constraints, rng)
|
||||
|
||||
# Format as strings
|
||||
puzzle_str = self._puzzle_to_string(puzzle, constraints)
|
||||
|
|
@ -82,6 +81,10 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
puzzle_grid: List[List[int]],
|
||||
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]
|
||||
) -> str:
|
||||
"""
|
||||
Formats a Futoshiki puzzle grid as a string with constraints.
|
||||
Constraints are represented as '<', '>', '\u2227', or '\u2228' between adjacent cells.
|
||||
"""
|
||||
n = len(puzzle_grid)
|
||||
|
||||
def cell_str(val: int) -> str:
|
||||
|
|
@ -150,24 +153,261 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
|
||||
return "\n".join(lines)
|
||||
|
||||
# currently this gets a bit slow for larger grid sizes as it relies on brute force backtracking
|
||||
# possible improvements: implement optimisations, using common rules in Futoshiki to reduce search space
|
||||
# see: https://www.futoshiki.com/how-to-solve
|
||||
# also see other solvers' approaches e.g. https://github.com/davidswarbrick/futoshiki-solver/blob/master/Futoshiki.py
|
||||
# however, I attempted some optimisations based on the code of the above solver, such as its recursive constraint following, and the result was actually quite a lot slower
|
||||
def _solve_logical(
|
||||
self,
|
||||
grid: List[List[int]],
|
||||
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||
) -> Tuple[List[List[int]], List[List[Set[int]]]]:
|
||||
"""
|
||||
Apply logical rules to progress solution.
|
||||
Returns current state if logical rules can't progress further.
|
||||
Logical rules are implemented based on the descriptions here: https://futoshiki.uk/
|
||||
"""
|
||||
size, working_grid = len(grid), copy.deepcopy(grid)
|
||||
|
||||
# Starting point all numbers are candidates for all unfilled squares
|
||||
candidates: List[List[Set[int]]] = [
|
||||
[
|
||||
set(range(1, len(grid) + 1)) if grid[r][c] == 0 else {grid[r][c]}
|
||||
for c in range(len(grid))
|
||||
]
|
||||
for r in range(len(grid))
|
||||
]
|
||||
|
||||
# Any cells > another cannot be 1, and any cells < another cannot be `size`
|
||||
# This is separated from the repeated function below to avoid redundant checks
|
||||
for ((r1, c1), (_, _)), rel in constraints.items():
|
||||
if rel == ">":
|
||||
candidates[r1][c1].discard(1)
|
||||
elif rel == "<":
|
||||
candidates[r1][c1].discard(size)
|
||||
|
||||
def _update_grid():
|
||||
"""Update solution wherever a cell's candidates set is reduced to a single element."""
|
||||
for r in range(len(working_grid)):
|
||||
for c in range(len(working_grid)):
|
||||
if working_grid[r][c] == 0 and len(candidates[r][c]) == 1:
|
||||
working_grid[r][c] = next(iter(candidates[r][c]))
|
||||
|
||||
def _try_solve_logical() -> bool:
|
||||
progress = False
|
||||
|
||||
# Eliminate candidates based on numbers already placed
|
||||
# If a number is placed in a cell, it cannot be a candidate for any other cell in the same row or column
|
||||
for r in range(len(working_grid)):
|
||||
for c in range(len(working_grid)):
|
||||
if working_grid[r][c] == 0:
|
||||
continue
|
||||
for cc in range(len(working_grid)):
|
||||
if cc != c and working_grid[r][c] in candidates[r][cc]:
|
||||
candidates[r][cc].discard(working_grid[r][c])
|
||||
progress = True
|
||||
for rr in range(len(working_grid)):
|
||||
if rr != r and working_grid[r][c] in candidates[rr][c]:
|
||||
candidates[rr][c].discard(working_grid[r][c])
|
||||
progress = True
|
||||
|
||||
_update_grid()
|
||||
|
||||
# Eliminate candidates based on constraints
|
||||
# Based on currently filled values, eliminate candidates that violate constraints
|
||||
def _eliminate_by_constraint(rc_less: Tuple[int, int], rc_greater: Tuple[int, int]) -> bool:
|
||||
r_less, c_less = rc_less
|
||||
r_greater, c_greater = rc_greater
|
||||
progress = False
|
||||
|
||||
if working_grid[r_less][c_less] != 0:
|
||||
# greater must only have candidates > less
|
||||
for v in candidates[r_greater][c_greater].copy():
|
||||
if v <= working_grid[r_less][c_less] and v in candidates[r_greater][c_greater]:
|
||||
candidates[r_greater][c_greater].discard(v)
|
||||
progress = True
|
||||
|
||||
if working_grid[r_greater][c_greater] != 0:
|
||||
# less must only have candidates < greater
|
||||
for v in candidates[r_less][c_less].copy():
|
||||
if v >= working_grid[r_greater][c_greater] and v in candidates[r_less][c_less]:
|
||||
candidates[r_less][c_less].discard(v)
|
||||
progress = True
|
||||
|
||||
return progress
|
||||
|
||||
for ((r1, c1), (r2, c2)), rel in constraints.items():
|
||||
v1, v2 = working_grid[r1][c1], working_grid[r2][c2]
|
||||
if v1 != 0 and v2 != 0: # both already filled, skip
|
||||
continue
|
||||
if rel == "<":
|
||||
progress |= _eliminate_by_constraint((r1, c1), (r2, c2))
|
||||
elif rel == ">":
|
||||
progress |= _eliminate_by_constraint((r2, c2), (r1, c1))
|
||||
|
||||
_update_grid()
|
||||
|
||||
# Seek "hidden singles" - cells where a candidate is unique in the row or column
|
||||
for r in range(len(working_grid)):
|
||||
for c in range(len(working_grid)):
|
||||
if working_grid[r][c] != 0:
|
||||
continue
|
||||
for v in candidates[r][c]:
|
||||
if sum(v in candidates[r][cc] for cc in range(len(working_grid))) == 1:
|
||||
candidates[r][c] = {v} # candidate unique in row
|
||||
break
|
||||
if sum(v in candidates[rr][c] for rr in range(len(working_grid))) == 1:
|
||||
candidates[r][c] = {v} # candidate unique in column
|
||||
break
|
||||
|
||||
_update_grid()
|
||||
|
||||
# Seek "naked pairs" if same pair of candidates twice in a row or col, with nothing else in those two cells
|
||||
# Remove them from other cells in row/col
|
||||
for r in range(len(working_grid)):
|
||||
for c in range(len(working_grid)):
|
||||
if working_grid[r][c] != 0 or len(candidates[r][c]) != 2:
|
||||
continue
|
||||
for cc in range(len(working_grid)):
|
||||
if cc != c and candidates[r][cc] == candidates[r][c]:
|
||||
for ccc in range(len(working_grid)):
|
||||
if ccc != c and ccc != cc and candidates[r][c].intersection(candidates[r][ccc]):
|
||||
candidates[r][ccc] -= candidates[r][c]
|
||||
progress = True
|
||||
for rr in range(len(working_grid)):
|
||||
if rr != r and candidates[rr][c] == candidates[r][c]:
|
||||
for rrr in range(len(working_grid)):
|
||||
if rrr != r and rrr != rr and candidates[r][c].intersection(candidates[rrr][c]):
|
||||
candidates[rrr][c] -= candidates[r][c]
|
||||
progress = True
|
||||
|
||||
_update_grid()
|
||||
|
||||
# Seek "hidden pairs" - same pair of candidates occurs in two cells in a line, but nowhere else in the line
|
||||
# alongside other candidates (hence hidden). All other candidates can be removed from those two cells
|
||||
for r in range(len(working_grid)):
|
||||
for c in range(len(working_grid)):
|
||||
if working_grid[r][c] != 0:
|
||||
continue
|
||||
for cc in range(c + 1, len(working_grid)):
|
||||
if working_grid[r][cc] != 0:
|
||||
continue
|
||||
# Check if r, c shares a candidate pair with r, cc (maybe subset, not exact candidate set match)
|
||||
r_c_pairs = itertools.permutations(candidates[r][c], 2)
|
||||
r_cc_pairs = itertools.permutations(candidates[r][cc], 2)
|
||||
for pair in r_c_pairs:
|
||||
if pair not in r_cc_pairs:
|
||||
continue
|
||||
otherwise_unique = True
|
||||
# If this pair occurs elsewhere in the row, skip
|
||||
for ccc in range(len(working_grid)):
|
||||
if ccc in (c, cc):
|
||||
continue
|
||||
if pair in itertools.permutations(candidates[r][ccc], 2):
|
||||
otherwise_unique = False
|
||||
break
|
||||
if not otherwise_unique:
|
||||
continue
|
||||
# Found a hidden pair, remove all other candidates from these two cells
|
||||
candidates[r][c] = set(pair)
|
||||
candidates[r][cc] = set(pair)
|
||||
|
||||
for rr in range(r + 1, len(working_grid)):
|
||||
if working_grid[rr][c] != 0:
|
||||
continue
|
||||
# Check if r, c shares a candidate pair with rr, c (maybe subset, not exact candidate set match)
|
||||
r_c_pairs = itertools.permutations(candidates[r][c], 2)
|
||||
rr_c_pairs = itertools.permutations(candidates[rr][c], 2)
|
||||
for pair in r_c_pairs:
|
||||
if pair not in rr_c_pairs:
|
||||
continue
|
||||
otherwise_unique = True
|
||||
# If this pair occurs elsewhere in the col, skip
|
||||
for rrr in range(len(working_grid)):
|
||||
if rrr in (r, rr):
|
||||
continue
|
||||
if pair in itertools.permutations(candidates[rrr][c], 2):
|
||||
otherwise_unique = False
|
||||
break
|
||||
if not otherwise_unique:
|
||||
continue
|
||||
# Found a hidden pair, remove all other candidates from these two cells
|
||||
candidates[r][c] = set(pair)
|
||||
candidates[rr][c] = set(pair)
|
||||
|
||||
_update_grid()
|
||||
|
||||
# Seek X-wings by rows
|
||||
for v in range(1, size + 1):
|
||||
# If candidate is in the same 2 positions in 2 rows, and nowhere else in those rows
|
||||
# Delete from the 2 intersecting cols
|
||||
|
||||
# Find rows which have exactly 2 instances of the value in their candidate sets
|
||||
rows_with_v = [r for r in range(size) if sum(v in candidates[r][c] for c in range(size)) == 2]
|
||||
if len(rows_with_v) < 2:
|
||||
continue
|
||||
# Check whether the 2 columns with the value are the same in the 2 rows
|
||||
cols_with_v_per_row = [set() for _ in range(len(rows_with_v))]
|
||||
for i, r in enumerate(rows_with_v):
|
||||
for c in range(size):
|
||||
if v in candidates[r][c]:
|
||||
cols_with_v_per_row[i].add(c)
|
||||
# Check if there are a pair of tows with the same cols (there may be more than 2 rows)
|
||||
for i in range(len(rows_with_v)):
|
||||
for j in range(i + 1, len(rows_with_v)):
|
||||
if cols_with_v_per_row[i] == cols_with_v_per_row[j]:
|
||||
# Found an X-wing, remove candidate from the 2 cols
|
||||
for c in cols_with_v_per_row[i]:
|
||||
for rr in range(size):
|
||||
if rr not in (rows_with_v[i], rows_with_v[j]) and v in candidates[rr][c]:
|
||||
candidates[rr][c].discard(v)
|
||||
progress = True
|
||||
|
||||
# Seek X-wings by cols
|
||||
for v in range(1, size + 1):
|
||||
# If candidate is in the same 2 positions in 2 cols, and nowhere else in those cols
|
||||
# Delete from the 2 intersecting rows
|
||||
|
||||
# Find cols which have exactly 2 instances of the value in their candidate sets
|
||||
cols_with_v = [c for c in range(size) if sum(v in candidates[r][c] for r in range(size)) == 2]
|
||||
if len(cols_with_v) < 2:
|
||||
continue
|
||||
# Check whether the 2 rows with the value are the same in the 2 cols
|
||||
rows_with_v_per_col = [set() for _ in range(len(cols_with_v))]
|
||||
for i, c in enumerate(cols_with_v):
|
||||
for r in range(size):
|
||||
if v in candidates[r][c]:
|
||||
rows_with_v_per_col[i].add(r)
|
||||
# Check if there are a pair of cols with the same rows (there may be more than 2 cols)
|
||||
for i in range(len(cols_with_v)):
|
||||
for j in range(i + 1, len(cols_with_v)):
|
||||
if rows_with_v_per_col[i] == rows_with_v_per_col[j]:
|
||||
# Found an X-wing, remove candidate from the 2 rows
|
||||
for r in rows_with_v_per_col[i]:
|
||||
for cc in range(size):
|
||||
if cc not in (cols_with_v[i], cols_with_v[j]) and v in candidates[r][cc]:
|
||||
candidates[r][cc].discard(v)
|
||||
progress = True
|
||||
|
||||
_update_grid()
|
||||
|
||||
return progress
|
||||
|
||||
while _try_solve_logical():
|
||||
continue
|
||||
|
||||
return working_grid, candidates
|
||||
|
||||
def _solve(
|
||||
self,
|
||||
grid: List[List[int]],
|
||||
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||
rng: Random,
|
||||
find_multiple: bool = False,
|
||||
) -> List[List[int]] | None:
|
||||
"""
|
||||
Backtracking Futoshiki solver. Used to verify generated puzzles.
|
||||
Applies logical rules first then backtracks to fill gaps.
|
||||
Return solved grid, or None if unsolvable.
|
||||
If find_multiple: also return None if more than one solution found.
|
||||
"""
|
||||
|
||||
grid, candidates = self._solve_logical(grid, constraints)
|
||||
|
||||
size = len(grid)
|
||||
empty_cell = None
|
||||
|
||||
|
|
@ -179,76 +419,25 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
break
|
||||
if empty_cell:
|
||||
break
|
||||
|
||||
|
||||
# If no empty cell, solution is complete
|
||||
if not empty_cell:
|
||||
return copy.deepcopy(grid)
|
||||
|
||||
r, c = empty_cell
|
||||
for val in range(1, size + 1):
|
||||
if val not in candidates[r][c]:
|
||||
continue
|
||||
if self._is_valid(grid, r, c, val, constraints):
|
||||
grid[r][c] = val
|
||||
solution = self._solve(grid, constraints, rng, find_multiple)
|
||||
solution = self._solve(grid, constraints, rng)
|
||||
if solution is not None:
|
||||
# If find_multiple, continue searching to check for non-uniqueness
|
||||
if find_multiple and self._has_other_solution(solution, grid, constraints, rng):
|
||||
grid[r][c] = 0
|
||||
return None
|
||||
|
||||
grid[r][c] = 0
|
||||
return solution
|
||||
grid[r][c] = 0
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def _has_other_solution(
|
||||
self,
|
||||
existing_solution: List[List[int]],
|
||||
partial_grid: List[List[int]],
|
||||
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||
rng: Random,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if there's at least one solution different from existing_solution, given the partial_grid so far.
|
||||
This is a quick hack: we attempt to find another solution with a slight difference.
|
||||
A full approach is backtracking that tries to find any solution differing from existing_solution in >= 1 cell.
|
||||
"""
|
||||
# Each cell not set in partial_grid could be varied
|
||||
size = len(existing_solution)
|
||||
# Make a fresh puzzle using partial_grid
|
||||
puzzle_copy = copy.deepcopy(partial_grid)
|
||||
|
||||
def backtrack(i = 0, j = 0) -> bool:
|
||||
# Move past end of row
|
||||
if j == size:
|
||||
i += 1
|
||||
j = 0
|
||||
# Completed all rows
|
||||
if i == size:
|
||||
# Confirm puzzle_copy differs in at least one cell from existing_solution
|
||||
for rr in range(size):
|
||||
for cc in range(size):
|
||||
if puzzle_copy[rr][cc] != existing_solution[rr][cc]:
|
||||
return True
|
||||
return False
|
||||
|
||||
if puzzle_copy[i][j] != 0:
|
||||
# Move on
|
||||
return backtrack(i, j + 1)
|
||||
else:
|
||||
# Try different values
|
||||
vals = list(range(1, size + 1))
|
||||
rng.shuffle(vals)
|
||||
for val in vals:
|
||||
if self._is_valid(puzzle_copy, i, j, val, constraints):
|
||||
puzzle_copy[i][j] = val
|
||||
if backtrack(i, j + 1):
|
||||
return True
|
||||
puzzle_copy[i][j] = 0
|
||||
return False
|
||||
|
||||
return backtrack(0, 0)
|
||||
|
||||
def _is_valid(
|
||||
self,
|
||||
grid: List[List[int]],
|
||||
|
|
@ -334,7 +523,7 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
constraints = {}
|
||||
# For each pair of adjacent cells, we might add a constraint
|
||||
# P(adding a constraint) increases with difficulty on an arbitrary scale
|
||||
base_prob = 0.05 + 0.05 * difficulty
|
||||
base_prob = 0.03 + 0.07 * difficulty
|
||||
for r in range(size):
|
||||
for c in range(size):
|
||||
# Horizontal neighbor
|
||||
|
|
@ -353,44 +542,76 @@ class FutoshikiDataset(ProceduralDataset):
|
|||
constraints[((r, c), (r + 1, c))] = ">"
|
||||
return constraints
|
||||
|
||||
def count_solutions(self, grid, constraints, limit=2) -> int:
    """
    Count the completions of grid by plain backtracking, stopping at limit.

    Empty cells (value 0) are filled in row-major order with every value that
    self._is_valid accepts. The grid is modified in place during the search,
    and each trial value is reset to 0 afterwards.

    Returns the number of completions found, never more than limit.
    """
    n = len(grid)
    found = 0

    def search() -> None:
        nonlocal found
        # Stop descending once enough solutions are known
        if found >= limit:
            return

        # First empty cell in row-major order, if any
        target = next(((i, j) for i in range(n) for j in range(n) if grid[i][j] == 0), None)
        if target is None:
            # Nothing left to fill: this is one complete solution
            found += 1
            return

        i, j = target
        for candidate in range(1, n + 1):
            if not self._is_valid(grid, i, j, candidate, constraints):
                continue
            grid[i][j] = candidate
            search()
            grid[i][j] = 0

    search()
    return found
|
||||
|
||||
def _remove_clues(
|
||||
self,
|
||||
grid: List[List[int]],
|
||||
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||
difficulty: int,
|
||||
rng: Random,
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Remove clues from a full solution to try to maintain a unique-solution puzzle.
|
||||
The higher the difficulty, the more clues we remove.
|
||||
We remove in random order until we reach our target, or can't without losing uniqueness.
|
||||
"""
|
||||
size = len(grid)
|
||||
fill_fraction = [0.09, 0.07, 0.05, 0.03] # Easiest -> hardest
|
||||
target_filled = int(fill_fraction[difficulty] * (size * size))
|
||||
fill_fraction = 0.1
|
||||
target_filled = int(fill_fraction * (size * size))
|
||||
|
||||
coords = [(r, c) for r in range(size) for c in range(size)]
|
||||
rng.shuffle(coords)
|
||||
|
||||
def count_filled_cells(g):
|
||||
def _count_filled_cells(g):
|
||||
return sum(g[r][c] != 0 for r in range(size) for c in range(size))
|
||||
|
||||
for (r,c) in coords:
|
||||
if count_filled_cells(grid) <= target_filled:
|
||||
break # Removal target hit
|
||||
def _try_remove():
|
||||
for (r,c) in coords:
|
||||
if _count_filled_cells(grid) <= target_filled:
|
||||
break # Removal target hit
|
||||
|
||||
saved = grid[r][c]
|
||||
if saved == 0:
|
||||
continue
|
||||
# Try remove
|
||||
grid[r][c] = 0
|
||||
saved = grid[r][c]
|
||||
if saved == 0:
|
||||
continue
|
||||
# Try remove
|
||||
grid[r][c] = 0
|
||||
|
||||
# Check if unsolvable or non-unique
|
||||
puzzle_copy = copy.deepcopy(grid)
|
||||
sol = self._solve(puzzle_copy, constraints, rng, find_multiple=True)
|
||||
if sol is None:
|
||||
# Not solvable or non-unique, revert
|
||||
grid[r][c] = saved
|
||||
# Check if unsolvable
|
||||
sol = self._solve(copy.deepcopy(grid), constraints, rng)
|
||||
if sol is None:
|
||||
grid[r][c] = saved
|
||||
continue
|
||||
# Check if not unique
|
||||
if self.count_solutions(copy.deepcopy(grid), constraints, limit=2) > 1:
|
||||
grid[r][c] = saved
|
||||
|
||||
_try_remove()
|
||||
|
||||
# Second pass if we aren't close to target
|
||||
if _count_filled_cells(grid) > 2 * target_filled:
|
||||
rng.shuffle(coords)
|
||||
_try_remove()
|
||||
|
||||
return grid
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue