test: Add comprehensive unit tests for RushHourDataset

This commit is contained in:
Andreas Koepf (aider) 2025-02-14 17:27:20 +01:00 committed by Andreas Koepf
parent f0ad47a29a
commit ca5e23e195
2 changed files with 75 additions and 19 deletions

View file

@ -1,19 +1,19 @@
import random
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import random
from ..data import read_data_file from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
TEST_STRING = "BBoKMxDDDKMoIAALooIoJLEEooJFFNoGGoxN"
@dataclass @dataclass
class RushHourConfig: class RushHourConfig:
"""Configuration for Rush Hour puzzle generation""" """Configuration for Rush Hour puzzle generation"""
min_moves: int = 1 min_moves: int = 1
max_moves: int = 50 max_moves: int = 50
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
def validate(self) -> None: def validate(self) -> None:
@ -100,10 +100,10 @@ class RushHourDataset(ProceduralDataset):
def __init__(self, config: RushHourConfig): def __init__(self, config: RushHourConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
# Load and filter puzzles from data file # Load and filter puzzles from data file
self.puzzles: List[Tuple[str, int]] = [] # (board_config, min_moves) self.puzzles: List[Tuple[str, int]] = [] # (board_config, min_moves)
data = read_data_file("rush_18k.txt") data = read_data_file("rush_18k.txt")
for line in data.splitlines(): for line in data.splitlines():
if not line.strip(): if not line.strip():
@ -112,21 +112,19 @@ class RushHourDataset(ProceduralDataset):
if len(parts) >= 2: if len(parts) >= 2:
min_moves = int(parts[0]) min_moves = int(parts[0])
board_config = parts[1] board_config = parts[1]
if config.min_moves <= min_moves <= config.max_moves: if config.min_moves <= min_moves <= config.max_moves:
self.puzzles.append((board_config, min_moves)) self.puzzles.append((board_config, min_moves))
if not self.puzzles: if not self.puzzles:
raise ValueError( raise ValueError(f"No puzzles found with moves between {config.min_moves} and {config.max_moves}")
f"No puzzles found with moves between {config.min_moves} and {config.max_moves}"
)
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single Rush Hour puzzle """Generate a single Rush Hour puzzle
Args: Args:
idx: Index of the item to generate idx: Index of the item to generate
Returns: Returns:
dict with keys: dict with keys:
- question: str, the formatted board with instructions - question: str, the formatted board with instructions
@ -135,23 +133,23 @@ class RushHourDataset(ProceduralDataset):
""" """
# Create deterministic RNG from base seed and idx # Create deterministic RNG from base seed and idx
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
# Randomly select a puzzle meeting our criteria # Randomly select a puzzle meeting our criteria
board_config, min_moves = rng.choice(self.puzzles) board_config, min_moves = rng.choice(self.puzzles)
# Create board and get string representation # Create board and get string representation
board = Board(board_config) board = Board(board_config)
board_display = board.board_str() board_display = board.board_str()
instructions = ( instructions = (
"Move the red car (AA) to the exit on the right.\n" "Move the red car (AA) to the exit on the right.\n"
"Specify moves in the format: 'F+1 K+1 M-1 C+3 H+2 ...'\n" "Specify moves in the format: 'F+1 K+1 M-1 C+3 H+2 ...'\n"
"where the letter is the vehicle and +/- number is spaces to move right/left or down/up." "where the letter is the vehicle and +/- number is spaces to move right/left or down/up."
) )
return { return {
"question": f"{instructions}\n\nBoard:\n{board_display}", "question": f"{instructions}\n\nBoard:\n{board_display}",
"answer": "", # Multiple valid solutions exist "answer": None, # Multiple valid solutions exist
"metadata": { "metadata": {
"board_config": board_config, "board_config": board_config,
"min_moves": min_moves, "min_moves": min_moves,
@ -330,4 +328,3 @@ class Board:
# Register the dataset # Register the dataset
register_dataset("rush_hour", RushHourDataset, RushHourConfig) register_dataset("rush_hour", RushHourDataset, RushHourConfig)

View file

@ -1,6 +1,65 @@
import pytest import pytest
from reasoning_gym.games.rush_hour import Board from reasoning_gym.games.rush_hour import Board, RushHourConfig, RushHourDataset
def test_rush_hour_config_validation():
    """Check that each invalid configuration is rejected with an AssertionError."""
    invalid_kwargs = (
        {"min_moves": 0},
        {"min_moves": 3, "max_moves": 2},
        {"size": 0},
    )
    for kwargs in invalid_kwargs:
        # Construction stays inside the raises-block, so an AssertionError from
        # either the constructor or validate() satisfies the expectation.
        with pytest.raises(AssertionError):
            config = RushHourConfig(**kwargs)
            config.validate()
def test_rush_hour_deterministic():
    """Two datasets built from the same seeded config must yield equal metadata."""
    config = RushHourConfig(seed=42, size=10, min_moves=1, max_moves=50)
    first = RushHourDataset(config)
    second = RushHourDataset(config)

    item_count = len(first)
    for idx in range(item_count):
        # Only the metadata dicts are compared, index by index.
        expected = first[idx]["metadata"]
        actual = second[idx]["metadata"]
        assert expected == actual
def test_rush_hour_items():
    """Validate the structure and metadata of every generated item."""
    config = RushHourConfig(min_moves=1, max_moves=10, size=10, seed=42)
    dataset = RushHourDataset(config)

    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)

        # Top-level keys every item must expose.
        for key in ("question", "answer", "metadata"):
            assert key in sample

        # Metadata must carry the board description and its move count.
        metadata = sample["metadata"]
        for key in ("board_config", "min_moves"):
            assert key in metadata

        # The selected puzzle has to respect the configured move bounds.
        assert config.min_moves <= metadata["min_moves"] <= config.max_moves

        # One character per cell of the board.
        assert len(metadata["board_config"]) == 36  # 6x6 board
def test_rush_hour_move_filtering():
    """Every item drawn from a 5-10 move config must report a move count in that range."""
    config = RushHourConfig(min_moves=5, max_moves=10, size=10, seed=42)
    dataset = RushHourDataset(config)

    for idx in range(len(dataset)):
        moves = dataset[idx]["metadata"]["min_moves"]
        within_bounds = 5 <= moves <= 10
        assert within_bounds, f"Puzzle with {moves} moves outside configured range 5-10"
def test_perform_moves(): def test_perform_moves():