test: Add comprehensive unit tests for RushHourDataset

This commit is contained in:
Andreas Koepf (aider) 2025-02-14 17:27:20 +01:00 committed by Andreas Koepf
parent f0ad47a29a
commit ca5e23e195
2 changed files with 75 additions and 19 deletions

View file

@ -1,16 +1,16 @@
import random
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import random
from ..data import read_data_file from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
TEST_STRING = "BBoKMxDDDKMoIAALooIoJLEEooJFFNoGGoxN"
@dataclass @dataclass
class RushHourConfig: class RushHourConfig:
"""Configuration for Rush Hour puzzle generation""" """Configuration for Rush Hour puzzle generation"""
min_moves: int = 1 min_moves: int = 1
max_moves: int = 50 max_moves: int = 50
seed: Optional[int] = None seed: Optional[int] = None
@ -117,9 +117,7 @@ class RushHourDataset(ProceduralDataset):
self.puzzles.append((board_config, min_moves)) self.puzzles.append((board_config, min_moves))
if not self.puzzles: if not self.puzzles:
raise ValueError( raise ValueError(f"No puzzles found with moves between {config.min_moves} and {config.max_moves}")
f"No puzzles found with moves between {config.min_moves} and {config.max_moves}"
)
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single Rush Hour puzzle """Generate a single Rush Hour puzzle
@ -151,7 +149,7 @@ class RushHourDataset(ProceduralDataset):
return { return {
"question": f"{instructions}\n\nBoard:\n{board_display}", "question": f"{instructions}\n\nBoard:\n{board_display}",
"answer": "", # Multiple valid solutions exist "answer": None, # Multiple valid solutions exist
"metadata": { "metadata": {
"board_config": board_config, "board_config": board_config,
"min_moves": min_moves, "min_moves": min_moves,
@ -330,4 +328,3 @@ class Board:
# Register the dataset # Register the dataset
register_dataset("rush_hour", RushHourDataset, RushHourConfig) register_dataset("rush_hour", RushHourDataset, RushHourConfig)

View file

@ -1,6 +1,65 @@
import pytest import pytest
from reasoning_gym.games.rush_hour import Board from reasoning_gym.games.rush_hour import Board, RushHourConfig, RushHourDataset
def test_rush_hour_config_validation():
    """Check that each invalid configuration is rejected with an AssertionError."""
    invalid_kwargs = (
        {"min_moves": 0},
        {"min_moves": 3, "max_moves": 2},
        {"size": 0},
    )
    for kwargs in invalid_kwargs:
        # Construction stays inside the context manager so an AssertionError
        # from either the constructor or validate() satisfies the check.
        with pytest.raises(AssertionError):
            RushHourConfig(**kwargs).validate()
def test_rush_hour_deterministic():
    """Check that two datasets built from one seeded config yield equal metadata."""
    cfg = RushHourConfig(seed=42, size=10, min_moves=1, max_moves=50)
    first = RushHourDataset(cfg)
    second = RushHourDataset(cfg)

    # Compare item by item so a failure points at the first diverging index.
    for idx in range(len(first)):
        left_meta = first[idx]["metadata"]
        right_meta = second[idx]["metadata"]
        assert left_meta == right_meta
def test_rush_hour_items():
    """Check the structure and basic invariants of every generated item."""
    config = RushHourConfig(min_moves=1, max_moves=10, size=10, seed=42)
    dataset = RushHourDataset(config)

    for index in range(len(dataset)):
        entry = dataset[index]
        assert isinstance(entry, dict)

        # Top-level keys every item must expose.
        for top_level_key in ("question", "answer", "metadata"):
            assert top_level_key in entry

        # Fields the metadata must carry.
        meta = entry["metadata"]
        for meta_key in ("board_config", "min_moves"):
            assert meta_key in meta

        # The recorded move count has to respect the configured bounds.
        lower, upper = config.min_moves, config.max_moves
        assert lower <= meta["min_moves"] <= upper

        # board_config length must be 36 (6x6 board).
        assert len(meta["board_config"]) == 36
def test_rush_hour_move_filtering():
    """Check that every puzzle served respects the configured move-count window."""
    cfg = RushHourConfig(min_moves=5, max_moves=10, size=10, seed=42)
    puzzles = RushHourDataset(cfg)

    for position in range(len(puzzles)):
        metadata = puzzles[position]["metadata"]
        moves = metadata["min_moves"]
        # Bounds mirror the min_moves/max_moves values passed to the config above.
        assert 5 <= moves <= 10, f"Puzzle with {moves} moves outside configured range 5-10"
def test_perform_moves(): def test_perform_moves():