Greatly speed up solver

This commit is contained in:
Oliver 2025-02-09 21:23:53 +00:00
parent 145ceeb109
commit dce5d9367d

View file

@ -2,10 +2,9 @@
import copy import copy
import itertools import itertools
import random
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Set, Tuple
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -59,7 +58,7 @@ class FutoshikiDataset(ProceduralDataset):
# Add random adjacency constraints consistent with generated solved grid # Add random adjacency constraints consistent with generated solved grid
constraints = self._generate_random_constraints(solution, self.config.difficulty, rng) constraints = self._generate_random_constraints(solution, self.config.difficulty, rng)
# Starting with full solution, remove clues to desired difficulty # Starting with full solution, remove clues to desired difficulty
puzzle = self._remove_clues(copy.deepcopy(solution), constraints, self.config.difficulty, rng) puzzle = self._remove_clues(copy.deepcopy(solution), constraints, rng)
# Format as strings # Format as strings
puzzle_str = self._puzzle_to_string(puzzle, constraints) puzzle_str = self._puzzle_to_string(puzzle, constraints)
@ -82,6 +81,10 @@ class FutoshikiDataset(ProceduralDataset):
puzzle_grid: List[List[int]], puzzle_grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str] constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]
) -> str: ) -> str:
"""
Formats a Futoshiki puzzle grid as a string with constraints.
Constraints are represented as '<', '>', '\u2227', or '\u2228' between adjacent cells.
"""
n = len(puzzle_grid) n = len(puzzle_grid)
def cell_str(val: int) -> str: def cell_str(val: int) -> str:
@ -150,24 +153,261 @@ class FutoshikiDataset(ProceduralDataset):
return "\n".join(lines) return "\n".join(lines)
# currently this gets a bit slow for larger grid sizes as it relies on brute force backtracking def _solve_logical(
# possible improvements: implement optimisations, using common rules in Futoshiki to reduce search space self,
# see: https://www.futoshiki.com/how-to-solve grid: List[List[int]],
# also see other solvers' approaches e.g. https://github.com/davidswarbrick/futoshiki-solver/blob/master/Futoshiki.py constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
# however I attempted some optimisations based on the code of the above parser, such as the recursive constraint following, and it was actually quite a lot slower ) -> Tuple[List[List[int]], List[List[Set[int]]]]:
"""
Apply logical rules to progress solution.
Returns current state if logical rules can't progress further.
Logical rules are implemented based on the descriptions here: https://futoshiki.uk/
"""
size, working_grid = len(grid), copy.deepcopy(grid)
# Starting point all numbers are candidates for all unfilled squares
candidates: List[List[Set[int]]] = [
[
set(range(1, len(grid) + 1)) if grid[r][c] == 0 else {grid[r][c]}
for c in range(len(grid))
]
for r in range(len(grid))
]
# Any cells > another cannot be 1, and any cells < another cannot be `size`
# This is separated from the repeated function below to avoid redundant checks
for ((r1, c1), (_, _)), rel in constraints.items():
if rel == ">":
candidates[r1][c1].discard(1)
elif rel == "<":
candidates[r1][c1].discard(size)
def _update_grid():
"""Update solution wherever a cell's candidates set is reduced to a single element."""
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] == 0 and len(candidates[r][c]) == 1:
working_grid[r][c] = next(iter(candidates[r][c]))
def _try_solve_logical() -> bool:
progress = False
# Eliminate candidates based on numbers already placed
# If a number is placed in a cell, it cannot be a candidate for any other cell in the same row or column
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] == 0:
continue
for cc in range(len(working_grid)):
if cc != c and working_grid[r][c] in candidates[r][cc]:
candidates[r][cc].discard(working_grid[r][c])
progress = True
for rr in range(len(working_grid)):
if rr != r and working_grid[r][c] in candidates[rr][c]:
candidates[rr][c].discard(working_grid[r][c])
progress = True
_update_grid()
# Eliminate candidates based on constraints
# Based on currently filled values, eliminate candidates that violate constraints
def _eliminate_by_constraint(rc_less: Tuple[int, int], rc_greater: Tuple[int, int]) -> bool:
r_less, c_less = rc_less
r_greater, c_greater = rc_greater
progress = False
if working_grid[r_less][c_less] != 0:
# greater must only have candidates > less
for v in candidates[r_greater][c_greater].copy():
if v <= working_grid[r_less][c_less] and v in candidates[r_greater][c_greater]:
candidates[r_greater][c_greater].discard(v)
progress = True
if working_grid[r_greater][c_greater] != 0:
# less must only have candidates < greater
for v in candidates[r_less][c_less].copy():
if v >= working_grid[r_greater][c_greater] and v in candidates[r_less][c_less]:
candidates[r_less][c_less].discard(v)
progress = True
return progress
for ((r1, c1), (r2, c2)), rel in constraints.items():
v1, v2 = working_grid[r1][c1], working_grid[r2][c2]
if v1 != 0 and v2 != 0: # both already filled, skip
continue
if rel == "<":
progress |= _eliminate_by_constraint((r1, c1), (r2, c2))
elif rel == ">":
progress |= _eliminate_by_constraint((r2, c2), (r1, c1))
_update_grid()
# Seek "hidden singles" - cells where a candidate is unique in the row or column
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0:
continue
for v in candidates[r][c]:
if sum(v in candidates[r][cc] for cc in range(len(working_grid))) == 1:
candidates[r][c] = {v} # candidate unique in row
break
if sum(v in candidates[rr][c] for rr in range(len(working_grid))) == 1:
candidates[r][c] = {v} # candidate unique in column
break
_update_grid()
# Seek "naked pairs" if same pair of candidates twice in a row or col, with nothing else in those two cells
# Remove them from other cells in row/col
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0 or len(candidates[r][c]) != 2:
continue
for cc in range(len(working_grid)):
if cc != c and candidates[r][cc] == candidates[r][c]:
for ccc in range(len(working_grid)):
if ccc != c and ccc != cc and candidates[r][c].intersection(candidates[r][ccc]):
candidates[r][ccc] -= candidates[r][c]
progress = True
for rr in range(len(working_grid)):
if rr != r and candidates[rr][c] == candidates[r][c]:
for rrr in range(len(working_grid)):
if rrr != r and rrr != rr and candidates[r][c].intersection(candidates[rrr][c]):
candidates[rrr][c] -= candidates[r][c]
progress = True
_update_grid()
# Seek "hidden pairs" - same pair of candidates occurs in two cells in a line, but nowhere else in the line
# alongside other candidates (hence hidden). All other candidates can be removed from those two cells
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0:
continue
for cc in range(c + 1, len(working_grid)):
if working_grid[r][cc] != 0:
continue
# Check if r, c shares a candidate pair with r, cc (maybe subset, not exact candidate set match)
r_c_pairs = itertools.permutations(candidates[r][c], 2)
r_cc_pairs = itertools.permutations(candidates[r][cc], 2)
for pair in r_c_pairs:
if pair not in r_cc_pairs:
continue
otherwise_unique = True
# If this pair occurs elsewhere in the row, skip
for ccc in range(len(working_grid)):
if ccc in (c, cc):
continue
if pair in itertools.permutations(candidates[r][ccc], 2):
otherwise_unique = False
break
if not otherwise_unique:
continue
# Found a hidden pair, remove all other candidates from these two cells
candidates[r][c] = set(pair)
candidates[r][cc] = set(pair)
for rr in range(r + 1, len(working_grid)):
if working_grid[rr][c] != 0:
continue
# Check if r, c shares a candidate pair with rr, c (maybe subset, not exact candidate set match)
r_c_pairs = itertools.permutations(candidates[r][c], 2)
rr_c_pairs = itertools.permutations(candidates[rr][c], 2)
for pair in r_c_pairs:
if pair not in rr_c_pairs:
continue
otherwise_unique = True
# If this pair occurs elsewhere in the col, skip
for rrr in range(len(working_grid)):
if rrr in (r, rr):
continue
if pair in itertools.permutations(candidates[rrr][c], 2):
otherwise_unique = False
break
if not otherwise_unique:
continue
# Found a hidden pair, remove all other candidates from these two cells
candidates[r][c] = set(pair)
candidates[rr][c] = set(pair)
_update_grid()
# Seek X-wings by rows
for v in range(1, size + 1):
# If candidate is in the same 2 positions in 2 rows, and nowhere else in those rows
# Delete from the 2 intersecting cols
# Find rows which have exactly 2 instances of the value in their candidate sets
rows_with_v = [r for r in range(size) if sum(v in candidates[r][c] for c in range(size)) == 2]
if len(rows_with_v) < 2:
continue
# Check whether the 2 columns with the value are the same in the 2 rows
cols_with_v_per_row = [set() for _ in range(len(rows_with_v))]
for i, r in enumerate(rows_with_v):
for c in range(size):
if v in candidates[r][c]:
cols_with_v_per_row[i].add(c)
# Check if there are a pair of tows with the same cols (there may be more than 2 rows)
for i in range(len(rows_with_v)):
for j in range(i + 1, len(rows_with_v)):
if cols_with_v_per_row[i] == cols_with_v_per_row[j]:
# Found an X-wing, remove candidate from the 2 cols
for c in cols_with_v_per_row[i]:
for rr in range(size):
if rr not in (rows_with_v[i], rows_with_v[j]) and v in candidates[rr][c]:
candidates[rr][c].discard(v)
progress = True
# Seek X-wings by cols
for v in range(1, size + 1):
# If candidate is in the same 2 positions in 2 cols, and nowhere else in those cols
# Delete from the 2 intersecting rows
# Find cols which have exactly 2 instances of the value in their candidate sets
cols_with_v = [c for c in range(size) if sum(v in candidates[r][c] for r in range(size)) == 2]
if len(cols_with_v) < 2:
continue
# Check whether the 2 rows with the value are the same in the 2 cols
rows_with_v_per_col = [set() for _ in range(len(cols_with_v))]
for i, c in enumerate(cols_with_v):
for r in range(size):
if v in candidates[r][c]:
rows_with_v_per_col[i].add(r)
# Check if there are a pair of cols with the same rows (there may be more than 2 cols)
for i in range(len(cols_with_v)):
for j in range(i + 1, len(cols_with_v)):
if rows_with_v_per_col[i] == rows_with_v_per_col[j]:
# Found an X-wing, remove candidate from the 2 rows
for r in rows_with_v_per_col[i]:
for cc in range(size):
if cc not in (cols_with_v[i], cols_with_v[j]) and v in candidates[r][cc]:
candidates[r][cc].discard(v)
progress = True
_update_grid()
return progress
while _try_solve_logical():
continue
return working_grid, candidates
def _solve( def _solve(
self, self,
grid: List[List[int]], grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
rng: Random, rng: Random,
find_multiple: bool = False,
) -> List[List[int]] | None: ) -> List[List[int]] | None:
""" """
Backtracking Futoshiki solver. Used to verify generated puzzles. Backtracking Futoshiki solver. Used to verify generated puzzles.
Applies logical rules first then backtracks to fill gaps.
Return solved grid, or None if unsolvable. Return solved grid, or None if unsolvable.
If find_multiple: also return None if more than one solution found.
""" """
grid, candidates = self._solve_logical(grid, constraints)
size = len(grid) size = len(grid)
empty_cell = None empty_cell = None
@ -179,76 +419,25 @@ class FutoshikiDataset(ProceduralDataset):
break break
if empty_cell: if empty_cell:
break break
# If no empty cell, solution is complete # If no empty cell, solution is complete
if not empty_cell: if not empty_cell:
return copy.deepcopy(grid) return copy.deepcopy(grid)
r, c = empty_cell r, c = empty_cell
for val in range(1, size + 1): for val in range(1, size + 1):
if val not in candidates[r][c]:
continue
if self._is_valid(grid, r, c, val, constraints): if self._is_valid(grid, r, c, val, constraints):
grid[r][c] = val grid[r][c] = val
solution = self._solve(grid, constraints, rng, find_multiple) solution = self._solve(grid, constraints, rng)
if solution is not None: if solution is not None:
# If find_multiple, continue searching to check for non-uniqueness
if find_multiple and self._has_other_solution(solution, grid, constraints, rng):
grid[r][c] = 0
return None
grid[r][c] = 0 grid[r][c] = 0
return solution return solution
grid[r][c] = 0 grid[r][c] = 0
return None return None
def _has_other_solution(
self,
existing_solution: List[List[int]],
partial_grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
rng: Random,
) -> bool:
"""
Check if there's at least one solution different from existing_solution, given the partial_grid so far.
This is a quick hack: we attempt to find another solution with a slight difference.
A full approach is backtracking that tries to find any solution differing from existing_solution in >= 1 cell.
"""
# Each cell not set in partial_grid could be varied
size = len(existing_solution)
# Make a fresh puzzle using partial_grid
puzzle_copy = copy.deepcopy(partial_grid)
def backtrack(i = 0, j = 0) -> bool:
# Move past end of row
if j == size:
i += 1
j = 0
# Completed all rows
if i == size:
# Confirm puzzle_copy differs in at least one cell from existing_solution
for rr in range(size):
for cc in range(size):
if puzzle_copy[rr][cc] != existing_solution[rr][cc]:
return True
return False
if puzzle_copy[i][j] != 0:
# Move on
return backtrack(i, j + 1)
else:
# Try different values
vals = list(range(1, size + 1))
rng.shuffle(vals)
for val in vals:
if self._is_valid(puzzle_copy, i, j, val, constraints):
puzzle_copy[i][j] = val
if backtrack(i, j + 1):
return True
puzzle_copy[i][j] = 0
return False
return backtrack(0, 0)
def _is_valid( def _is_valid(
self, self,
grid: List[List[int]], grid: List[List[int]],
@ -334,7 +523,7 @@ class FutoshikiDataset(ProceduralDataset):
constraints = {} constraints = {}
# For each pair of adjacent cells, we might add a constraint # For each pair of adjacent cells, we might add a constraint
# P(adding a constraint) increases with difficulty on an arbitrary scale # P(adding a constraint) increases with difficulty on an arbitrary scale
base_prob = 0.05 + 0.05 * difficulty base_prob = 0.03 + 0.07 * difficulty
for r in range(size): for r in range(size):
for c in range(size): for c in range(size):
# Horizontal neighbor # Horizontal neighbor
@ -353,44 +542,76 @@ class FutoshikiDataset(ProceduralDataset):
constraints[((r, c), (r + 1, c))] = ">" constraints[((r, c), (r + 1, c))] = ">"
return constraints return constraints
def count_solutions(self, grid, constraints, limit=2) -> int:
size = len(grid)
count = 0
def backtrack():
nonlocal count
# Early exit if limit reached
if count >= limit:
return
# Find the next empty cell
for r in range(size):
for c in range(size):
if grid[r][c] == 0:
for val in range(1, size + 1):
if self._is_valid(grid, r, c, val, constraints):
grid[r][c] = val
backtrack()
grid[r][c] = 0
return
count += 1
backtrack()
return count
def _remove_clues( def _remove_clues(
self, self,
grid: List[List[int]], grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
difficulty: int,
rng: Random, rng: Random,
) -> List[List[int]]: ) -> List[List[int]]:
""" """
Remove clues from a full solution to try to maintain a unique-solution puzzle. Remove clues from a full solution to try to maintain a unique-solution puzzle.
The higher the difficulty, the more clues we remove.
We remove in random order until we reach our target, or can't without losing uniqueness. We remove in random order until we reach our target, or can't without losing uniqueness.
""" """
size = len(grid) size = len(grid)
fill_fraction = [0.09, 0.07, 0.05, 0.03] # Easiest -> hardest fill_fraction = 0.1
target_filled = int(fill_fraction[difficulty] * (size * size)) target_filled = int(fill_fraction * (size * size))
coords = [(r, c) for r in range(size) for c in range(size)] coords = [(r, c) for r in range(size) for c in range(size)]
rng.shuffle(coords) rng.shuffle(coords)
def count_filled_cells(g): def _count_filled_cells(g):
return sum(g[r][c] != 0 for r in range(size) for c in range(size)) return sum(g[r][c] != 0 for r in range(size) for c in range(size))
for (r,c) in coords: def _try_remove():
if count_filled_cells(grid) <= target_filled: for (r,c) in coords:
break # Removal target hit if _count_filled_cells(grid) <= target_filled:
break # Removal target hit
saved = grid[r][c] saved = grid[r][c]
if saved == 0: if saved == 0:
continue continue
# Try remove # Try remove
grid[r][c] = 0 grid[r][c] = 0
# Check if unsolvable or non-unique # Check if unsolvable
puzzle_copy = copy.deepcopy(grid) sol = self._solve(copy.deepcopy(grid), constraints, rng)
sol = self._solve(puzzle_copy, constraints, rng, find_multiple=True) if sol is None:
if sol is None: grid[r][c] = saved
# Not solvable or non-unique, revert continue
grid[r][c] = saved # Check if not unique
if self.count_solutions(copy.deepcopy(grid), constraints, limit=2) > 1:
grid[r][c] = saved
_try_remove()
# Second pass if we aren't close to target
if _count_filled_cells(grid) > 2 * target_filled:
rng.shuffle(coords)
_try_remove()
return grid return grid