From dce5d9367d6db8560db938e14c7fb4e6155a7770 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sun, 9 Feb 2025 21:23:53 +0000 Subject: [PATCH] Greatly speed up solver --- reasoning_gym/games/futoshiki.py | 393 ++++++++++++++++++++++++------- 1 file changed, 307 insertions(+), 86 deletions(-) diff --git a/reasoning_gym/games/futoshiki.py b/reasoning_gym/games/futoshiki.py index 4a6aa125..8e5100d5 100644 --- a/reasoning_gym/games/futoshiki.py +++ b/reasoning_gym/games/futoshiki.py @@ -2,10 +2,9 @@ import copy import itertools -import random from dataclasses import dataclass from random import Random -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from ..factory import ProceduralDataset, register_dataset @@ -59,7 +58,7 @@ class FutoshikiDataset(ProceduralDataset): # Add random adjacency constraints consistent with generated solved grid constraints = self._generate_random_constraints(solution, self.config.difficulty, rng) # Starting with full solution, remove clues to desired difficulty - puzzle = self._remove_clues(copy.deepcopy(solution), constraints, self.config.difficulty, rng) + puzzle = self._remove_clues(copy.deepcopy(solution), constraints, rng) # Format as strings puzzle_str = self._puzzle_to_string(puzzle, constraints) @@ -82,6 +81,10 @@ class FutoshikiDataset(ProceduralDataset): puzzle_grid: List[List[int]], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str] ) -> str: + """ + Formats a Futoshiki puzzle grid as a string with constraints. + Constraints are represented as '<', '>', '\u2227', or '\u2228' between adjacent cells. + """ n = len(puzzle_grid) def cell_str(val: int) -> str: @@ -150,24 +153,261 @@ class FutoshikiDataset(ProceduralDataset): return "\n".join(lines) - # currently this gets a bit slow for larger grid sizes as it relies on brute force backtracking - # possible improvements: implement optimisations, using common rules in Futoshiki to reduce search space - # see: https://www.futoshiki.com/how-to-solve - # also see other solvers' approaches e.g. https://github.com/davidswarbrick/futoshiki-solver/blob/master/Futoshiki.py - # however I attempted some optimisations based on the code of the above parser, such as the recursive constraint following, and it was actually quite a lot slower + def _solve_logical( + self, + grid: List[List[int]], + constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], + ) -> Tuple[List[List[int]], List[List[Set[int]]]]: + """ + Apply logical rules to progress solution. + Returns current state if logical rules can't progress further. + Logical rules are implemented based on the descriptions here: https://futoshiki.uk/ + """ + size, working_grid = len(grid), copy.deepcopy(grid) + + # Starting point all numbers are candidates for all unfilled squares + candidates: List[List[Set[int]]] = [ + [ + set(range(1, len(grid) + 1)) if grid[r][c] == 0 else {grid[r][c]} + for c in range(len(grid)) + ] + for r in range(len(grid)) + ] + + # Any cells > another cannot be 1, and any cells < another cannot be `size` + # This is separated from the repeated function below to avoid redundant checks + for ((r1, c1), (_, _)), rel in constraints.items(): + if rel == ">": + candidates[r1][c1].discard(1) + elif rel == "<": + candidates[r1][c1].discard(size) + + def _update_grid(): + """Update solution wherever a cell's candidates set is reduced to a single element.""" + for r in range(len(working_grid)): + for c in range(len(working_grid)): + if working_grid[r][c] == 0 and len(candidates[r][c]) == 1: + working_grid[r][c] = next(iter(candidates[r][c])) + + def _try_solve_logical() -> bool: + progress = False + + # Eliminate candidates based on numbers already placed + # If a number is placed in a cell, it cannot be a candidate for any other cell in the same row or column + for r in range(len(working_grid)): + for c in range(len(working_grid)): + if working_grid[r][c] == 0: + continue + for cc in range(len(working_grid)): + if cc != c and working_grid[r][c] in candidates[r][cc]: + candidates[r][cc].discard(working_grid[r][c]) + progress = True + for rr in range(len(working_grid)): + if rr != r and working_grid[r][c] in candidates[rr][c]: + candidates[rr][c].discard(working_grid[r][c]) + progress = True + + _update_grid() + + # Eliminate candidates based on constraints + # Based on currently filled values, eliminate candidates that violate constraints + def _eliminate_by_constraint(rc_less: Tuple[int, int], rc_greater: Tuple[int, int]) -> bool: + r_less, c_less = rc_less + r_greater, c_greater = rc_greater + progress = False + + if working_grid[r_less][c_less] != 0: + # greater must only have candidates > less + for v in candidates[r_greater][c_greater].copy(): + if v <= working_grid[r_less][c_less] and v in candidates[r_greater][c_greater]: + candidates[r_greater][c_greater].discard(v) + progress = True + + if working_grid[r_greater][c_greater] != 0: + # less must only have candidates < greater + for v in candidates[r_less][c_less].copy(): + if v >= working_grid[r_greater][c_greater] and v in candidates[r_less][c_less]: + candidates[r_less][c_less].discard(v) + progress = True + + return progress + + for ((r1, c1), (r2, c2)), rel in constraints.items(): + v1, v2 = working_grid[r1][c1], working_grid[r2][c2] + if v1 != 0 and v2 != 0: # both already filled, skip + continue + if rel == "<": + progress |= _eliminate_by_constraint((r1, c1), (r2, c2)) + elif rel == ">": + progress |= _eliminate_by_constraint((r2, c2), (r1, c1)) + + _update_grid() + + # Seek "hidden singles" - cells where a candidate is unique in the row or column + for r in range(len(working_grid)): + for c in range(len(working_grid)): + if working_grid[r][c] != 0: + continue + for v in candidates[r][c]: + if sum(v in candidates[r][cc] for cc in range(len(working_grid))) == 1: + candidates[r][c] = {v} # candidate unique in row + break + if sum(v in candidates[rr][c] for rr in range(len(working_grid))) == 1: + candidates[r][c] = {v} # candidate unique in column + break + + _update_grid() + + # Seek "naked pairs" if same pair of candidates twice in a row or col, with nothing else in those two cells + # Remove them from other cells in row/col + for r in range(len(working_grid)): + for c in range(len(working_grid)): + if working_grid[r][c] != 0 or len(candidates[r][c]) != 2: + continue + for cc in range(len(working_grid)): + if cc != c and candidates[r][cc] == candidates[r][c]: + for ccc in range(len(working_grid)): + if ccc != c and ccc != cc and candidates[r][c].intersection(candidates[r][ccc]): + candidates[r][ccc] -= candidates[r][c] + progress = True + for rr in range(len(working_grid)): + if rr != r and candidates[rr][c] == candidates[r][c]: + for rrr in range(len(working_grid)): + if rrr != r and rrr != rr and candidates[r][c].intersection(candidates[rrr][c]): + candidates[rrr][c] -= candidates[r][c] + progress = True + + _update_grid() + + # Seek "hidden pairs" - same pair of candidates occurs in two cells in a line, but nowhere else in the line + # alongside other candidates (hence hidden). All other candidates can be removed from those two cells + for r in range(len(working_grid)): + for c in range(len(working_grid)): + if working_grid[r][c] != 0: + continue + for cc in range(c + 1, len(working_grid)): + if working_grid[r][cc] != 0: + continue + # Check if r, c shares a candidate pair with r, cc (maybe subset, not exact candidate set match) + r_c_pairs = itertools.permutations(candidates[r][c], 2) + r_cc_pairs = itertools.permutations(candidates[r][cc], 2) + for pair in r_c_pairs: + if pair not in r_cc_pairs: + continue + otherwise_unique = True + # If this pair occurs elsewhere in the row, skip + for ccc in range(len(working_grid)): + if ccc in (c, cc): + continue + if pair in itertools.permutations(candidates[r][ccc], 2): + otherwise_unique = False + break + if not otherwise_unique: + continue + # Found a hidden pair, remove all other candidates from these two cells + candidates[r][c] = set(pair) + candidates[r][cc] = set(pair) + + for rr in range(r + 1, len(working_grid)): + if working_grid[rr][c] != 0: + continue + # Check if r, c shares a candidate pair with rr, c (maybe subset, not exact candidate set match) + r_c_pairs = itertools.permutations(candidates[r][c], 2) + rr_c_pairs = itertools.permutations(candidates[rr][c], 2) + for pair in r_c_pairs: + if pair not in rr_c_pairs: + continue + otherwise_unique = True + # If this pair occurs elsewhere in the col, skip + for rrr in range(len(working_grid)): + if rrr in (r, rr): + continue + if pair in itertools.permutations(candidates[rrr][c], 2): + otherwise_unique = False + break + if not otherwise_unique: + continue + # Found a hidden pair, remove all other candidates from these two cells + candidates[r][c] = set(pair) + candidates[rr][c] = set(pair) + + _update_grid() + + # Seek X-wings by rows + for v in range(1, size + 1): + # If candidate is in the same 2 positions in 2 rows, and nowhere else in those rows + # Delete from the 2 intersecting cols + + # Find rows which have exactly 2 instances of the value in their candidate sets + rows_with_v = [r for r in range(size) if sum(v in candidates[r][c] for c in range(size)) == 2] + if len(rows_with_v) < 2: + continue + # Check whether the 2 columns with the value are the same in the 2 rows + cols_with_v_per_row = [set() for _ in range(len(rows_with_v))] + for i, r in enumerate(rows_with_v): + for c in range(size): + if v in candidates[r][c]: + cols_with_v_per_row[i].add(c) + # Check if there are a pair of tows with the same cols (there may be more than 2 rows) + for i in range(len(rows_with_v)): + for j in range(i + 1, len(rows_with_v)): + if cols_with_v_per_row[i] == cols_with_v_per_row[j]: + # Found an X-wing, remove candidate from the 2 cols + for c in cols_with_v_per_row[i]: + for rr in range(size): + if rr not in (rows_with_v[i], rows_with_v[j]) and v in candidates[rr][c]: + candidates[rr][c].discard(v) + progress = True + + # Seek X-wings by cols + for v in range(1, size + 1): + # If candidate is in the same 2 positions in 2 cols, and nowhere else in those cols + # Delete from the 2 intersecting rows + + # Find cols which have exactly 2 instances of the value in their candidate sets + cols_with_v = [c for c in range(size) if sum(v in candidates[r][c] for r in range(size)) == 2] + if len(cols_with_v) < 2: + continue + # Check whether the 2 rows with the value are the same in the 2 cols + rows_with_v_per_col = [set() for _ in range(len(cols_with_v))] + for i, c in enumerate(cols_with_v): + for r in range(size): + if v in candidates[r][c]: + rows_with_v_per_col[i].add(r) + # Check if there are a pair of cols with the same rows (there may be more than 2 cols) + for i in range(len(cols_with_v)): + for j in range(i + 1, len(cols_with_v)): + if rows_with_v_per_col[i] == rows_with_v_per_col[j]: + # Found an X-wing, remove candidate from the 2 rows + for r in rows_with_v_per_col[i]: + for cc in range(size): + if cc not in (cols_with_v[i], cols_with_v[j]) and v in candidates[r][cc]: + candidates[r][cc].discard(v) + progress = True + + _update_grid() + + return progress + + while _try_solve_logical(): + continue + + return working_grid, candidates def _solve( self, grid: List[List[int]], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], rng: Random, - find_multiple: bool = False, ) -> List[List[int]] | None: """ Backtracking Futoshiki solver. Used to verify generated puzzles. + Applies logical rules first then backtracks to fill gaps. Return solved grid, or None if unsolvable. - If find_multiple: also return None if more than one solution found. """ + + grid, candidates = self._solve_logical(grid, constraints) + size = len(grid) empty_cell = None @@ -179,76 +419,25 @@ class FutoshikiDataset(ProceduralDataset): break if empty_cell: break - + # If no empty cell, solution is complete if not empty_cell: return copy.deepcopy(grid) r, c = empty_cell for val in range(1, size + 1): + if val not in candidates[r][c]: + continue if self._is_valid(grid, r, c, val, constraints): grid[r][c] = val - solution = self._solve(grid, constraints, rng, find_multiple) + solution = self._solve(grid, constraints, rng) if solution is not None: - # If find_multiple, continue searching to check for non-uniqueness - if find_multiple and self._has_other_solution(solution, grid, constraints, rng): - grid[r][c] = 0 - return None - grid[r][c] = 0 return solution grid[r][c] = 0 - + return None - def _has_other_solution( - self, - existing_solution: List[List[int]], - partial_grid: List[List[int]], - constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], - rng: Random, - ) -> bool: - """ - Check if there's at least one solution different from existing_solution, given the partial_grid so far. - This is a quick hack: we attempt to find another solution with a slight difference. - A full approach is backtracking that tries to find any solution differing from existing_solution in >= 1 cell. - """ - # Each cell not set in partial_grid could be varied - size = len(existing_solution) - # Make a fresh puzzle using partial_grid - puzzle_copy = copy.deepcopy(partial_grid) - - def backtrack(i = 0, j = 0) -> bool: - # Move past end of row - if j == size: - i += 1 - j = 0 - # Completed all rows - if i == size: - # Confirm puzzle_copy differs in at least one cell from existing_solution - for rr in range(size): - for cc in range(size): - if puzzle_copy[rr][cc] != existing_solution[rr][cc]: - return True - return False - - if puzzle_copy[i][j] != 0: - # Move on - return backtrack(i, j + 1) - else: - # Try different values - vals = list(range(1, size + 1)) - rng.shuffle(vals) - for val in vals: - if self._is_valid(puzzle_copy, i, j, val, constraints): - puzzle_copy[i][j] = val - if backtrack(i, j + 1): - return True - puzzle_copy[i][j] = 0 - return False - - return backtrack(0, 0) - def _is_valid( self, grid: List[List[int]], @@ -334,7 +523,7 @@ class FutoshikiDataset(ProceduralDataset): constraints = {} # For each pair of adjacent cells, we might add a constraint # P(adding a constraint) increases with difficulty on an arbitrary scale - base_prob = 0.05 + 0.05 * difficulty + base_prob = 0.03 + 0.07 * difficulty for r in range(size): for c in range(size): # Horizontal neighbor @@ -353,44 +542,76 @@ class FutoshikiDataset(ProceduralDataset): constraints[((r, c), (r + 1, c))] = ">" return constraints + def count_solutions(self, grid, constraints, limit=2) -> int: + size = len(grid) + count = 0 + + def backtrack(): + nonlocal count + # Early exit if limit reached + if count >= limit: + return + # Find the next empty cell + for r in range(size): + for c in range(size): + if grid[r][c] == 0: + for val in range(1, size + 1): + if self._is_valid(grid, r, c, val, constraints): + grid[r][c] = val + backtrack() + grid[r][c] = 0 + return + count += 1 + + backtrack() + return count + def _remove_clues( self, grid: List[List[int]], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str], - difficulty: int, rng: Random, ) -> List[List[int]]: """ Remove clues from a full solution to try to maintain a unique-solution puzzle. - The higher the difficulty, the more clues we remove. We remove in random order until we reach our target, or can't without losing uniqueness. """ size = len(grid) - fill_fraction = [0.09, 0.07, 0.05, 0.03] # Easiest -> hardest - target_filled = int(fill_fraction[difficulty] * (size * size)) + fill_fraction = 0.1 + target_filled = int(fill_fraction * (size * size)) coords = [(r, c) for r in range(size) for c in range(size)] rng.shuffle(coords) - def count_filled_cells(g): + def _count_filled_cells(g): return sum(g[r][c] != 0 for r in range(size) for c in range(size)) - for (r,c) in coords: - if count_filled_cells(grid) <= target_filled: - break # Removal target hit + def _try_remove(): + for (r,c) in coords: + if _count_filled_cells(grid) <= target_filled: + break # Removal target hit - saved = grid[r][c] - if saved == 0: - continue - # Try remove - grid[r][c] = 0 + saved = grid[r][c] + if saved == 0: + continue + # Try remove + grid[r][c] = 0 - # Check if unsolvable or non-unique - puzzle_copy = copy.deepcopy(grid) - sol = self._solve(puzzle_copy, constraints, rng, find_multiple=True) - if sol is None: - # Not solvable or non-unique, revert - grid[r][c] = saved + # Check if unsolvable + sol = self._solve(copy.deepcopy(grid), constraints, rng) + if sol is None: + grid[r][c] = saved + continue + # Check if not unique + if self.count_solutions(copy.deepcopy(grid), constraints, limit=2) > 1: + grid[r][c] = saved + + _try_remove() + + # Second pass if we aren't close to target + if _count_filled_cells(grid) > 2 * target_filled: + rng.shuffle(coords) + _try_remove() return grid