diff --git a/reasoning_gym/games/rush_hour.py b/reasoning_gym/games/rush_hour.py index 61fa4f27..98f0a5ad 100644 --- a/reasoning_gym/games/rush_hour.py +++ b/reasoning_gym/games/rush_hour.py @@ -1,19 +1,19 @@ +import random import re from dataclasses import dataclass from typing import List, Optional, Tuple -import random from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset -TEST_STRING = "BBoKMxDDDKMoIAALooIoJLEEooJFFNoGGoxN" @dataclass class RushHourConfig: """Configuration for Rush Hour puzzle generation""" + min_moves: int = 1 max_moves: int = 50 - seed: Optional[int] = None + seed: Optional[int] = None size: int = 500 def validate(self) -> None: @@ -100,10 +100,10 @@ class RushHourDataset(ProceduralDataset): def __init__(self, config: RushHourConfig): super().__init__(config=config, seed=config.seed, size=config.size) - + # Load and filter puzzles from data file self.puzzles: List[Tuple[str, int]] = [] # (board_config, min_moves) - + data = read_data_file("rush_18k.txt") for line in data.splitlines(): if not line.strip(): @@ -112,21 +112,19 @@ class RushHourDataset(ProceduralDataset): if len(parts) >= 2: min_moves = int(parts[0]) board_config = parts[1] - + if config.min_moves <= min_moves <= config.max_moves: self.puzzles.append((board_config, min_moves)) - + if not self.puzzles: - raise ValueError( - f"No puzzles found with moves between {config.min_moves} and {config.max_moves}" - ) + raise ValueError(f"No puzzles found with moves between {config.min_moves} and {config.max_moves}") def __getitem__(self, idx: int) -> dict: """Generate a single Rush Hour puzzle - + Args: idx: Index of the item to generate - + Returns: dict with keys: - question: str, the formatted board with instructions @@ -135,23 +133,23 @@ class RushHourDataset(ProceduralDataset): """ # Create deterministic RNG from base seed and idx rng = random.Random(self.seed + idx) - + # Randomly select a puzzle meeting our criteria board_config, min_moves = rng.choice(self.puzzles) - + # Create board and get string representation board = Board(board_config) board_display = board.board_str() - + instructions = ( "Move the red car (AA) to the exit on the right.\n" "Specify moves in the format: 'F+1 K+1 M-1 C+3 H+2 ...'\n" "where the letter is the vehicle and +/- number is spaces to move right/left or down/up." ) - + return { "question": f"{instructions}\n\nBoard:\n{board_display}", - "answer": "", # Multiple valid solutions exist + "answer": None, # Multiple valid solutions exist "metadata": { "board_config": board_config, "min_moves": min_moves, @@ -330,4 +328,3 @@ class Board: # Register the dataset register_dataset("rush_hour", RushHourDataset, RushHourConfig) - diff --git a/tests/test_rush_hour.py b/tests/test_rush_hour.py index 1ae748f6..5123ebc5 100644 --- a/tests/test_rush_hour.py +++ b/tests/test_rush_hour.py @@ -1,6 +1,65 @@ import pytest -from reasoning_gym.games.rush_hour import Board +from reasoning_gym.games.rush_hour import Board, RushHourConfig, RushHourDataset + + +def test_rush_hour_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = RushHourConfig(min_moves=0) + config.validate() + + with pytest.raises(AssertionError): + config = RushHourConfig(min_moves=3, max_moves=2) + config.validate() + + with pytest.raises(AssertionError): + config = RushHourConfig(size=0) + config.validate() + + +def test_rush_hour_deterministic(): + """Test that dataset generates same items with same seed""" + config = RushHourConfig(seed=42, size=10, min_moves=1, max_moves=50) + dataset1 = RushHourDataset(config) + dataset2 = RushHourDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i]["metadata"] == dataset2[i]["metadata"] + + +def test_rush_hour_items(): + """Test basic properties of generated items""" + config = RushHourConfig(min_moves=1, max_moves=10, size=10, seed=42) + dataset = RushHourDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify metadata contains required fields + assert "board_config" in item["metadata"] + assert "min_moves" in item["metadata"] + + # Verify min_moves is within configured range + assert config.min_moves <= item["metadata"]["min_moves"] <= config.max_moves + + # Verify board_config is valid length + assert len(item["metadata"]["board_config"]) == 36 # 6x6 board + + +def test_rush_hour_move_filtering(): + """Test that puzzles are filtered by move count""" + config = RushHourConfig(min_moves=5, max_moves=10, size=10, seed=42) + dataset = RushHourDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + moves = item["metadata"]["min_moves"] + assert 5 <= moves <= 10, f"Puzzle with {moves} moves outside configured range 5-10" def test_perform_moves():