mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
test: Add comprehensive unit tests for RushHourDataset
This commit is contained in:
parent
f0ad47a29a
commit
ca5e23e195
2 changed files with 75 additions and 19 deletions
|
|
@ -1,19 +1,19 @@
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
import random
|
|
||||||
|
|
||||||
from ..data import read_data_file
|
from ..data import read_data_file
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
TEST_STRING = "BBoKMxDDDKMoIAALooIoJLEEooJFFNoGGoxN"
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RushHourConfig:
|
class RushHourConfig:
|
||||||
"""Configuration for Rush Hour puzzle generation"""
|
"""Configuration for Rush Hour puzzle generation"""
|
||||||
|
|
||||||
min_moves: int = 1
|
min_moves: int = 1
|
||||||
max_moves: int = 50
|
max_moves: int = 50
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
|
|
@ -100,10 +100,10 @@ class RushHourDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: RushHourConfig):
|
def __init__(self, config: RushHourConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
# Load and filter puzzles from data file
|
# Load and filter puzzles from data file
|
||||||
self.puzzles: List[Tuple[str, int]] = [] # (board_config, min_moves)
|
self.puzzles: List[Tuple[str, int]] = [] # (board_config, min_moves)
|
||||||
|
|
||||||
data = read_data_file("rush_18k.txt")
|
data = read_data_file("rush_18k.txt")
|
||||||
for line in data.splitlines():
|
for line in data.splitlines():
|
||||||
if not line.strip():
|
if not line.strip():
|
||||||
|
|
@ -112,21 +112,19 @@ class RushHourDataset(ProceduralDataset):
|
||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
min_moves = int(parts[0])
|
min_moves = int(parts[0])
|
||||||
board_config = parts[1]
|
board_config = parts[1]
|
||||||
|
|
||||||
if config.min_moves <= min_moves <= config.max_moves:
|
if config.min_moves <= min_moves <= config.max_moves:
|
||||||
self.puzzles.append((board_config, min_moves))
|
self.puzzles.append((board_config, min_moves))
|
||||||
|
|
||||||
if not self.puzzles:
|
if not self.puzzles:
|
||||||
raise ValueError(
|
raise ValueError(f"No puzzles found with moves between {config.min_moves} and {config.max_moves}")
|
||||||
f"No puzzles found with moves between {config.min_moves} and {config.max_moves}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single Rush Hour puzzle
|
"""Generate a single Rush Hour puzzle
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idx: Index of the item to generate
|
idx: Index of the item to generate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with keys:
|
dict with keys:
|
||||||
- question: str, the formatted board with instructions
|
- question: str, the formatted board with instructions
|
||||||
|
|
@ -135,23 +133,23 @@ class RushHourDataset(ProceduralDataset):
|
||||||
"""
|
"""
|
||||||
# Create deterministic RNG from base seed and idx
|
# Create deterministic RNG from base seed and idx
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
# Randomly select a puzzle meeting our criteria
|
# Randomly select a puzzle meeting our criteria
|
||||||
board_config, min_moves = rng.choice(self.puzzles)
|
board_config, min_moves = rng.choice(self.puzzles)
|
||||||
|
|
||||||
# Create board and get string representation
|
# Create board and get string representation
|
||||||
board = Board(board_config)
|
board = Board(board_config)
|
||||||
board_display = board.board_str()
|
board_display = board.board_str()
|
||||||
|
|
||||||
instructions = (
|
instructions = (
|
||||||
"Move the red car (AA) to the exit on the right.\n"
|
"Move the red car (AA) to the exit on the right.\n"
|
||||||
"Specify moves in the format: 'F+1 K+1 M-1 C+3 H+2 ...'\n"
|
"Specify moves in the format: 'F+1 K+1 M-1 C+3 H+2 ...'\n"
|
||||||
"where the letter is the vehicle and +/- number is spaces to move right/left or down/up."
|
"where the letter is the vehicle and +/- number is spaces to move right/left or down/up."
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"{instructions}\n\nBoard:\n{board_display}",
|
"question": f"{instructions}\n\nBoard:\n{board_display}",
|
||||||
"answer": "", # Multiple valid solutions exist
|
"answer": None, # Multiple valid solutions exist
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"board_config": board_config,
|
"board_config": board_config,
|
||||||
"min_moves": min_moves,
|
"min_moves": min_moves,
|
||||||
|
|
@ -330,4 +328,3 @@ class Board:
|
||||||
|
|
||||||
# Register the dataset
|
# Register the dataset
|
||||||
register_dataset("rush_hour", RushHourDataset, RushHourConfig)
|
register_dataset("rush_hour", RushHourDataset, RushHourConfig)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,65 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.rush_hour import Board
|
from reasoning_gym.games.rush_hour import Board, RushHourConfig, RushHourDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_rush_hour_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RushHourConfig(min_moves=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RushHourConfig(min_moves=3, max_moves=2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RushHourConfig(size=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_rush_hour_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = RushHourConfig(seed=42, size=10, min_moves=1, max_moves=50)
|
||||||
|
dataset1 = RushHourDataset(config)
|
||||||
|
dataset2 = RushHourDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i]["metadata"] == dataset2[i]["metadata"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_rush_hour_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = RushHourConfig(min_moves=1, max_moves=10, size=10, seed=42)
|
||||||
|
dataset = RushHourDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Verify metadata contains required fields
|
||||||
|
assert "board_config" in item["metadata"]
|
||||||
|
assert "min_moves" in item["metadata"]
|
||||||
|
|
||||||
|
# Verify min_moves is within configured range
|
||||||
|
assert config.min_moves <= item["metadata"]["min_moves"] <= config.max_moves
|
||||||
|
|
||||||
|
# Verify board_config is valid length
|
||||||
|
assert len(item["metadata"]["board_config"]) == 36 # 6x6 board
|
||||||
|
|
||||||
|
|
||||||
|
def test_rush_hour_move_filtering():
|
||||||
|
"""Test that puzzles are filtered by move count"""
|
||||||
|
config = RushHourConfig(min_moves=5, max_moves=10, size=10, seed=42)
|
||||||
|
dataset = RushHourDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
moves = item["metadata"]["min_moves"]
|
||||||
|
assert 5 <= moves <= 10, f"Puzzle with {moves} moves outside configured range 5-10"
|
||||||
|
|
||||||
|
|
||||||
def test_perform_moves():
|
def test_perform_moves():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue