mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-02 17:45:58 +00:00
shortest path
This commit is contained in:
parent
a1a305c8d7
commit
df914dfb49
3 changed files with 250 additions and 0 deletions
123
tests/test_shortest_path.py
Normal file
123
tests/test_shortest_path.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Tests for Shortest Path questions generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathDataset
|
||||
|
||||
|
||||
def test_shortest_path_config_validation():
    """Each invalid setting listed below must be rejected with an AssertionError."""
    rejected_settings = [
        {"min_rows": -1},  # Negative not allowed
        {"min_rows": 0},  # Zero not allowed
        {"min_cols": -1},  # Negative not allowed
        {"min_cols": 0},  # Zero not allowed
        {"min_rows": 10, "max_rows": 5},  # min > max
        {"min_cols": 10, "max_cols": 5},  # min > max
        {"p_blocked": -0.1},  # negative value
        {"p_blocked": 1.1},  # greater than 1
    ]

    for settings in rejected_settings:
        # Construction happens inside the context manager so an AssertionError
        # raised while building the config is accepted just like one raised
        # by validate().
        with pytest.raises(AssertionError):
            ShortestPathConfig(**settings).validate()
|
||||
|
||||
|
||||
def test_shortest_path_dataset_deterministic():
    """Two datasets built from the same seeded config must agree at every index."""
    shared_config = ShortestPathConfig(seed=42, size=10)
    first = ShortestPathDataset(shared_config)
    second = ShortestPathDataset(shared_config)

    # Index-based access is deliberate: it compares what __getitem__ returns
    # for each position in the two datasets.
    for index in range(len(first)):
        from_first = first[index]
        from_second = second[index]
        assert from_first == from_second
|
||||
|
||||
|
||||
def test_shortest_path_dataset_items():
    """Check the structure and grid dimensions of every generated item.

    Each item must be a dict with "question", "answer" and "metadata" keys; the
    metadata must carry "matrix" and "solution". The matrix must respect the
    configured row/column bounds and contain at least one start cell ("*") and
    one end cell ("#").
    """
    # Name the bounds once so the config and the assertions cannot drift apart.
    min_dim, max_dim = 3, 5
    config = ShortestPathConfig(
        min_rows=min_dim, max_rows=max_dim, min_cols=min_dim, max_cols=max_dim, size=10, seed=42
    )
    dataset = ShortestPathDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]

        # Check item structure
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Check metadata
        metadata = item["metadata"]
        assert "matrix" in metadata
        assert "solution" in metadata

        matrix = metadata["matrix"]

        # Verify dimensions
        assert min_dim <= len(matrix) <= max_dim
        assert all(min_dim <= len(row) <= max_dim for row in matrix)

        # Verify that the grid has both endpoints
        cells = [cell for row in matrix for cell in row]
        assert any(cell == "*" for cell in cells)  # Start cell
        assert any(cell == "#" for cell in cells)  # End cell
|
||||
|
||||
|
||||
def test_shortest_path_dataset_iteration():
    """Iterating the dataset yields `config.size` items, identically on each pass."""
    config = ShortestPathConfig(size=5, seed=42)
    dataset = ShortestPathDataset(config)

    first_pass = list(dataset)
    assert len(first_pass) == config.size

    # A second full iteration must reproduce the first one exactly.
    second_pass = list(dataset)
    assert first_pass == second_pass
|
||||
|
||||
|
||||
def test_shortest_path_answer():
    """Test the _get_answer method on hand-written grids.

    In these grids "*" is the start cell and "#" is the end cell; the expected
    value for the unreachable case is -1.
    """
    config = ShortestPathConfig(seed=42)
    dataset = ShortestPathDataset(config)

    cases = [
        # Simple
        (
            [
                ["X", "X", "X", "X", "X"],
                ["X", "*", "O", "#", "X"],
                ["X", "O", "X", "O", "X"],
            ],
            2,
        ),
        # One shot example in prompt
        (
            [
                ["X", "X", "X", "X", "X"],
                ["X", "*", "O", "O", "X"],
                ["X", "O", "X", "O", "X"],
                ["X", "X", "X", "O", "#"],
            ],
            5,
        ),
        # Impossible solution
        (
            [
                ["X", "X", "X", "X", "X"],
                ["X", "*", "O", "O", "X"],
                ["X", "O", "X", "O", "X"],
                ["X", "X", "X", "X", "#"],
            ],
            -1,
        ),
    ]

    for grid, expected in cases:
        assert dataset._get_answer(grid) == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue