diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index 4d1ccd8f..fc4b9f19 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -2,6 +2,7 @@ from .course_schedule import CourseScheduleConfig, CourseScheduleDataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset from .largest_island import LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockDataset +from .shortest_path import ShortestPathConfig, ShortestPathDataset __all__ = [ "FamilyRelationshipsConfig", @@ -11,4 +12,6 @@ __all__ = [ "LargestIslandDataset", "CourseScheduleDataset", "CourseScheduleConfig", + "ShortestPathConfig", + "ShortestPathDataset", ] diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py new file mode 100644 index 00000000..f2740dfa --- /dev/null +++ b/reasoning_gym/graphs/shortest_path.py @@ -0,0 +1,124 @@ +"""Find the shortest path between a start and end point in a grid""" + +from collections import deque +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Your task is to find the length of the shortest path from the start to the destination point in a grid. + +The grid is represented as a matrix with the following types of cells: +- *: your starting point +- #: your destination point +- O: an open cell +- X: a blocked cell + +Therefore, you need to find the length of the shortest path from * to #, moving only through open cells. +If there is no path from * to #, return -1. + +Example: +- Input: Find the length of the shortest path from * to # in the following grid: + X X X X X + X * O O X + X O X O X + X X X O # +- Output: 5 + +Now, find the length of the shortest path from * to # in the following grid: +{grid} +""" + + +@dataclass +class ShortestPathConfig: + """Configuration for Shortest Path dataset generation""" + + min_rows: int = 10 + max_rows: int = 30 + min_cols: int = 10 + max_cols: int = 30 + p_blocked: float = 0.4 + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 1 <= self.min_rows, "min_rows must be at least 1" + assert self.min_rows <= self.max_rows, "min_rows must be less than or equal to max_rows" + assert 1 <= self.min_cols, "min_cols must be at least 1" + assert self.min_cols <= self.max_cols, "min_cols must be less than or equal to max_cols" + assert 0 <= self.p_blocked <= 1, "p_blocked must be between 0 and 1" + + +class ShortestPathDataset(ProceduralDataset): + """Generates Shortest Path exercises with configurable difficulty""" + + def __init__(self, config: ShortestPathConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def _get_grid(self, rng: Random) -> list[list[str]]: + """Generate a random grid with open and blocked cells""" + + rows, cols = rng.randint(self.config.min_rows, self.config.max_rows), rng.randint( + self.config.min_cols, self.config.max_cols + ) + grid = [["X" if rng.random() < self.config.p_blocked else "O" for _ in range(cols)] for _ in range(rows)] + + start_r, start_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1) + grid[start_r][start_c] = "*" + + while True: + end_r, end_c = rng.randint(0, rows - 1), rng.randint(0, cols - 1) + if (end_r, end_c) != (start_r, start_c): + grid[end_r][end_c] = "#" + break + + return grid + + def _matrix_to_str(self, matrix: list[list[int]]) -> str: + """Get a string representation of the matrix""" + return "\n".join(" ".join(str(x) for x in row) for row in matrix) + + def _get_answer(self, matrix: list[list[str]]) -> int: + """Run BFS to find the shortest path length""" + ROWS, COLS = len(matrix), len(matrix[0]) + DIRS = [(0, 1), (1, 0), (0, -1), (-1, 0)] + + start_r, start_c = next((r, c) for r in range(ROWS) for c in range(COLS) if matrix[r][c] == "*") + queue = deque([(start_r, start_c)]) + steps = 0 + + while queue: + steps += 1 + for _ in range(len(queue)): + r, c = queue.popleft() + for dr, dc in DIRS: + new_r, new_c = r + dr, c + dc + if 0 <= new_r < ROWS and 0 <= new_c < COLS: + if matrix[new_r][new_c] == "#": + return steps + if matrix[new_r][new_c] == "O": + matrix[new_r][new_c] = "X" + queue.append((new_r, new_c)) + + return -1 + + def __getitem__(self, idx: int) -> dict: + """Generate a single Shortest Path question""" + rng = Random(self.seed + idx) + + matrix = self._get_grid(rng) + matrix_str = self._matrix_to_str(matrix) + answer = self._get_answer(matrix) + + return { + "question": QUESTION_TEMPLATE.format(grid=matrix_str), + "answer": str(answer), + "metadata": {"matrix": matrix, "solution": answer}, + } + + +register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig) diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py new file mode 100644 index 00000000..78136304 --- /dev/null +++ b/tests/test_shortest_path.py @@ -0,0 +1,123 @@ +"""Tests for Shortest Path questions generation""" + +import pytest + +from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathDataset + + +def test_shortest_path_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_rows=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_rows=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_cols=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_cols=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_rows=10, max_rows=5) # min > max + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(min_cols=10, max_cols=5) + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(p_blocked=-0.1) + config.validate() + + with pytest.raises(AssertionError): + config = ShortestPathConfig(p_blocked=1.1) + config.validate() + + +def test_shortest_path_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = ShortestPathConfig(seed=42, size=10) + dataset1 = ShortestPathDataset(config) + dataset2 = ShortestPathDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_shortest_path_dataset_items(): + """Test basic properties of generated items""" + config = ShortestPathConfig(min_rows=3, max_rows=5, min_cols=3, max_cols=5, size=10, seed=42) + dataset = ShortestPathDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "matrix" in item["metadata"] + assert "solution" in item["metadata"] + + matrix = item["metadata"]["matrix"] + solution = item["metadata"]["solution"] + + # Verify values + assert len(matrix) >= 3 + assert len(matrix) <= 5 + assert all(len(row) >= 3 for row in matrix) + assert all(len(row) <= 5 for row in matrix) + assert any(cell == "*" for row in matrix for cell in row) # Start cell + assert any(cell == "#" for row in matrix for cell in row) # End cell + + +def test_shortest_path_dataset_iteration(): + """Test that iteration respects dataset size""" + config = ShortestPathConfig(size=5, seed=42) + dataset = ShortestPathDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_shortest_path_answer(): + """Test the _get_distances method""" + config = ShortestPathConfig(seed=42) + dataset = ShortestPathDataset(config) + + # Simple + matrix = [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "#", "X"], + ["X", "O", "X", "O", "X"], + ] + assert dataset._get_answer(matrix) == 2 + + # One shot example in prompt + matrix = [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "O", "X"], + ["X", "O", "X", "O", "X"], + ["X", "X", "X", "O", "#"], + ] + assert dataset._get_answer(matrix) == 5 + + # Impossible solution + matrix = [ + ["X", "X", "X", "X", "X"], + ["X", "*", "O", "O", "X"], + ["X", "O", "X", "O", "X"], + ["X", "X", "X", "X", "#"], + ] + assert dataset._get_answer(matrix) == -1