diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index f2740dfa..93ff605f 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -3,11 +3,11 @@ from collections import deque from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Optional from ..factory import ProceduralDataset, register_dataset -QUESTION_TEMPLATE = """Your task is to find the length of the shortest path from the start to the destination point in a grid. +QUESTION_TEMPLATE = """Your task is to find the shortest path from the start to the destination point in a grid. The grid is represented as a matrix with the following types of cells: - *: your starting point @@ -15,16 +15,24 @@ The grid is represented as a matrix with the following types of cells: - O: an open cell - X: a blocked cell -Therefore, you need to find the length of the shortest path from * to #, moving only through open cells. -If there is no path from * to #, return -1. +Therefore, you need to find the shortest path from * to #, moving only through open cells. +If there is no path from * to #, simply write "infeasible" (without quotes). -Example: +Example 1: - Input: Find the length of the shortest path from * to # in the following grid: X X X X X X * O O X X O X O X X X X O # -- Output: 5 +- Output: right right down down right + +Example 2: +- Input: Find the length of the shortest path from * to # in the following grid: + X X X X X + X * O O X + X O X O X + X X X X # +- Output: infeasible Now, find the length of the shortest path from * to # in the following grid: {grid} @@ -82,29 +90,75 @@ class ShortestPathDataset(ProceduralDataset): """Get a string representation of the matrix""" return "\n".join(" ".join(str(x) for x in row) for row in matrix) - def _get_answer(self, matrix: list[list[str]]) -> int: - """Run BFS to find the shortest path length""" + def _get_answer(self, matrix: list[list[str]]) -> list[str]: + """Run BFS to find the shortest path""" ROWS, COLS = len(matrix), len(matrix[0]) - DIRS = [(0, 1), (1, 0), (0, -1), (-1, 0)] + DIRS = [(0, 1, "right"), (1, 0, "down"), (0, -1, "left"), (-1, 0, "up")] start_r, start_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "*") - queue = deque([(start_r, start_c)]) - steps = 0 + queue = deque([(start_r, start_c, [])]) + visited = set((start_r, start_c)) while queue: - steps += 1 - for _ in range(len(queue)): - r, c = queue.popleft() - for dr, dc in DIRS: - new_r, new_c = r + dr, c + dc - if 0 <= new_r < ROWS and 0 <= new_c < COLS: - if matrix[new_r][new_c] == "#": - return steps - if matrix[new_r][new_c] == "O": - matrix[new_r][new_c] = "X" - queue.append((new_r, new_c)) + r, c, path = queue.popleft() + for dr, dc, direction in DIRS: + new_r, new_c = r + dr, c + dc + if 0 <= new_r < ROWS and 0 <= new_c < COLS and (new_r, new_c) not in visited: + new_path = path + [direction] + if matrix[new_r][new_c] == "#": + return new_path + if matrix[new_r][new_c] == "O": + visited.add((new_r, new_c)) + queue.append((new_r, new_c, new_path)) - return -1 + return [] + + def _is_valid_path(self, matrix: list[list[str]], path: list[str]) -> bool: + """Verifies the path goes from * to # without crossing X cells""" + ROWS, COLS = len(matrix), len(matrix[0]) + DIRS = {"right": (0, 1), "down": (1, 0), "left": (0, -1), "up": (-1, 0)} + + start_r, start_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "*") + end_r, end_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "#") + + r, c = start_r, start_c + for direction in path: + if direction not in DIRS: + return False # Invalid direction + dr, dc = DIRS[direction] + r, c = r + dr, c + dc + if not (0 <= r < ROWS and 0 <= c < COLS): + return False + if matrix[r][c] == "X": + return False + + return (r, c) == (end_r, end_c) + + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + """Overwrite this method in derived classes if a single oracle answer is not available.""" + oracle_answer = entry["answer"].strip() + if answer is not None and len(answer) > 0: + answer = answer.strip() + + # Exact answer + if answer == oracle_answer: + return 1.0 + + matrix = entry["metadata"]["matrix"] + answer = answer.split() + oracle_answer = oracle_answer.split() + + # Path is valid and has the same length as the oracle answer + if self._is_valid_path(matrix, answer) and len(answer) == len(oracle_answer): + return 1.0 + + # Path is valid but has a larger length than the oracle answer + elif self._is_valid_path(matrix, answer): + return 0.5 + + return 0.01 + + return 0.0 def __getitem__(self, idx: int) -> dict: """Generate a single Shortest Path question""" @@ -113,10 +167,11 @@ class ShortestPathDataset(ProceduralDataset): matrix = self._get_grid(rng) matrix_str = self._matrix_to_str(matrix) answer = self._get_answer(matrix) + answer_str = " ".join(answer) if answer else "infeasible" return { "question": QUESTION_TEMPLATE.format(grid=matrix_str), - "answer": str(answer), + "answer": answer_str, "metadata": {"matrix": matrix, "solution": answer}, } diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py index 78136304..f726b7cf 100644 --- a/tests/test_shortest_path.py +++ b/tests/test_shortest_path.py @@ -102,7 +102,7 @@ def test_shortest_path_answer(): ["X", "*", "O", "#", "X"], ["X", "O", "X", "O", "X"], ] - assert dataset._get_answer(matrix) == 2 + assert " ".join(dataset._get_answer(matrix)) == "right right" # One shot example in prompt matrix = [ @@ -111,7 +111,7 @@ def test_shortest_path_answer(): ["X", "O", "X", "O", "X"], ["X", "X", "X", "O", "#"], ] - assert dataset._get_answer(matrix) == 5 + assert " ".join(dataset._get_answer(matrix)) == "right right down down right" # Impossible solution matrix = [ @@ -120,4 +120,62 @@ def test_shortest_path_answer(): ["X", "O", "X", "O", "X"], ["X", "X", "X", "X", "#"], ] - assert dataset._get_answer(matrix) == -1 + assert dataset._get_answer(matrix) == [] + + # Multiple valid solutions of same size + entry = { + "answer": "right right down down", + "metadata": { + "matrix": [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "O", "X"], + ["X", "O", "X", "O", "X"], + ["X", "O", "O", "#", "X"], + ] + }, + } + assert dataset.score_answer("right right down down", entry) == 1.0 + assert dataset.score_answer("down down right right", entry) == 1.0 + + # Partial solution (valid, but longer than oracle) + entry = { + "answer": "right right", + "metadata": { + "matrix": [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "#", "X"], + ["X", "O", "X", "O", "X"], + ["X", "O", "O", "O", "X"], + ] + }, + } + assert dataset.score_answer("right right", entry) == 1.0 + assert dataset.score_answer("down down right right up up", entry) == 0.5 + + # Invalid solution (steps over X) + entry = { + "answer": "right right down down", + "metadata": { + "matrix": [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "O", "X"], + ["X", "O", "X", "O", "X"], + ["X", "O", "O", "#", "X"], + ] + }, + } + assert dataset.score_answer("right down right down", entry) == 0.01 + + # Answer is None + entry = { + "answer": "right right down down", + "metadata": { + "matrix": [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "O", "X"], + ["X", "O", "X", "O", "X"], + ["X", "O", "O", "#", "X"], + ] + }, + } + assert dataset.score_answer(None, entry) == 0.0