predict actual path

This commit is contained in:
Zafir Stojanovski 2025-02-23 18:24:23 +01:00
parent 97b3097984
commit c5f37d5e9f
2 changed files with 140 additions and 27 deletions

View file

@ -3,11 +3,11 @@
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to find the length of the shortest path from the start to the destination point in a grid.
QUESTION_TEMPLATE = """Your task is to find the shortest path from the start to the destination point in a grid.
The grid is represented as a matrix with the following types of cells:
- *: your starting point
@ -15,16 +15,24 @@ The grid is represented as a matrix with the following types of cells:
- O: an open cell
- X: a blocked cell
Therefore, you need to find the length of the shortest path from * to #, moving only through open cells.
If there is no path from * to #, return -1.
Therefore, you need to find the shortest path from * to #, moving only through open cells.
If there is no path from * to #, simply write "infeasible" (without quotes).
Example:
Example 1:
- Input: Find the length of the shortest path from * to # in the following grid:
X X X X X
X * O O X
X O X O X
X X X O #
- Output: 5
- Output: right right down down right
Example 2:
- Input: Find the length of the shortest path from * to # in the following grid:
X X X X X
X * O O X
X O X O X
X X X X #
- Output: infeasible
Now, find the length of the shortest path from * to # in the following grid:
{grid}
@ -82,29 +90,75 @@ class ShortestPathDataset(ProceduralDataset):
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def _get_answer(self, matrix: list[list[str]]) -> int:
"""Run BFS to find the shortest path length"""
def _get_answer(self, matrix: list[list[str]]) -> list[str]:
"""Run BFS to find the shortest path"""
ROWS, COLS = len(matrix), len(matrix[0])
DIRS = [(0, 1), (1, 0), (0, -1), (-1, 0)]
DIRS = [(0, 1, "right"), (1, 0, "down"), (0, -1, "left"), (-1, 0, "up")]
start_r, start_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "*")
queue = deque([(start_r, start_c)])
steps = 0
queue = deque([(start_r, start_c, [])])
visited = set((start_r, start_c))
while queue:
steps += 1
for _ in range(len(queue)):
r, c = queue.popleft()
for dr, dc in DIRS:
new_r, new_c = r + dr, c + dc
if 0 <= new_r < ROWS and 0 <= new_c < COLS:
if matrix[new_r][new_c] == "#":
return steps
if matrix[new_r][new_c] == "O":
matrix[new_r][new_c] = "X"
queue.append((new_r, new_c))
r, c, path = queue.popleft()
for dr, dc, direction in DIRS:
new_r, new_c = r + dr, c + dc
if 0 <= new_r < ROWS and 0 <= new_c < COLS and (new_r, new_c) not in visited:
new_path = path + [direction]
if matrix[new_r][new_c] == "#":
return new_path
if matrix[new_r][new_c] == "O":
visited.add((new_r, new_c))
queue.append((new_r, new_c, new_path))
return -1
return []
def _is_valid_path(self, matrix: list[list[str]], path: list[str]) -> bool:
"""Verifies the path goes from * to # without crossing X cells"""
ROWS, COLS = len(matrix), len(matrix[0])
DIRS = {"right": (0, 1), "down": (1, 0), "left": (0, -1), "up": (-1, 0)}
start_r, start_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "*")
end_r, end_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "#")
r, c = start_r, start_c
for direction in path:
if direction not in DIRS:
return False # Invalid direction
dr, dc = DIRS[direction]
r, c = r + dr, c + dc
if not (0 <= r < ROWS and 0 <= c < COLS):
return False
if matrix[r][c] == "X":
return False
return (r, c) == (end_r, end_c)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
if answer is not None and len(answer) > 0:
answer = answer.strip()
# Exact answer
if answer == oracle_answer:
return 1.0
matrix = entry["metadata"]["matrix"]
answer = answer.split()
oracle_answer = oracle_answer.split()
# Path is valid and has the same length as the oracle answer
if self._is_valid_path(matrix, answer) and len(answer) == len(oracle_answer):
return 1.0
# Path is valid but has a larger length than the oracle answer
elif self._is_valid_path(matrix, answer):
return 0.5
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict:
"""Generate a single Shortest Path question"""
@ -113,10 +167,11 @@ class ShortestPathDataset(ProceduralDataset):
matrix = self._get_grid(rng)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_answer(matrix)
answer_str = " ".join(answer) if answer else "infeasible"
return {
"question": QUESTION_TEMPLATE.format(grid=matrix_str),
"answer": str(answer),
"answer": answer_str,
"metadata": {"matrix": matrix, "solution": answer},
}