mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
lint
This commit is contained in:
parent
1e27021e11
commit
6c564b3dd9
13 changed files with 305 additions and 317 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from random import Random
|
||||
import time
|
||||
from random import Random
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset
|
||||
|
||||
|
|
@ -63,14 +64,14 @@ def test_word_ladder_dataset_unique_pairs():
|
|||
"""Test that generated word pairs are unique"""
|
||||
config = WordLadderConfig(size=50, seed=42)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Track all generated pairs
|
||||
seen_pairs = set()
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
pair = (
|
||||
min(item["metadata"]["start_word"], item["metadata"]["end_word"]),
|
||||
max(item["metadata"]["start_word"], item["metadata"]["end_word"])
|
||||
max(item["metadata"]["start_word"], item["metadata"]["end_word"]),
|
||||
)
|
||||
assert pair not in seen_pairs, f"Duplicate pair found: {pair}"
|
||||
seen_pairs.add(pair)
|
||||
|
|
@ -80,23 +81,13 @@ def test_word_ladder_dataset_items():
|
|||
"""Test basic properties of generated items"""
|
||||
# Test with specific chain length constraints
|
||||
config1 = WordLadderConfig(
|
||||
min_word_length=3,
|
||||
max_word_length=5,
|
||||
min_chain_length=3,
|
||||
max_chain_length=5,
|
||||
size=10,
|
||||
seed=42
|
||||
min_word_length=3, max_word_length=5, min_chain_length=3, max_chain_length=5, size=10, seed=42
|
||||
)
|
||||
dataset1 = WordLadderDataset(config1)
|
||||
|
||||
# Test with shortest path mode
|
||||
config2 = WordLadderConfig(
|
||||
min_word_length=3,
|
||||
max_word_length=5,
|
||||
min_chain_length=-1,
|
||||
max_chain_length=-1,
|
||||
size=10,
|
||||
seed=42
|
||||
min_word_length=3, max_word_length=5, min_chain_length=-1, max_chain_length=-1, size=10, seed=42
|
||||
)
|
||||
dataset2 = WordLadderDataset(config2)
|
||||
|
||||
|
|
@ -124,13 +115,13 @@ def test_word_ladder_dataset_items():
|
|||
|
||||
# Verify solution chain from answer
|
||||
solution_chain = item["answer"].split(",")
|
||||
|
||||
|
||||
# Verify chain length based on config
|
||||
if dataset.config.min_chain_length == -1:
|
||||
assert len(solution_chain) >= 3
|
||||
else:
|
||||
assert dataset.config.min_chain_length <= len(solution_chain) <= dataset.config.max_chain_length
|
||||
|
||||
|
||||
assert solution_chain[0] == metadata["start_word"]
|
||||
assert solution_chain[-1] == metadata["end_word"]
|
||||
|
||||
|
|
@ -171,35 +162,32 @@ def test_word_ladder_path_finding():
|
|||
min_chain_length=-1, # Shortest path mode
|
||||
max_chain_length=-1,
|
||||
size=10,
|
||||
seed=42
|
||||
seed=42,
|
||||
)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Test finding path between known words
|
||||
word_set = dataset.word_sets[4]
|
||||
path = dataset._find_path("WORD", "FIND", word_set)
|
||||
|
||||
|
||||
# Verify path properties
|
||||
assert path is not None
|
||||
assert path[0] == "WORD"
|
||||
assert path[-1] == "FIND"
|
||||
assert len(path) >= 3
|
||||
|
||||
|
||||
# Verify each step differs by only one letter
|
||||
for i in range(len(path)-1):
|
||||
for i in range(len(path) - 1):
|
||||
current = path[i]
|
||||
next_word = path[i+1]
|
||||
next_word = path[i + 1]
|
||||
assert next_word in dataset._get_neighbors(current, word_set)
|
||||
|
||||
|
||||
def test_word_ladder_csv_loading():
|
||||
"""Test word loading from CSV"""
|
||||
config = WordLadderConfig(
|
||||
min_word_length=3,
|
||||
max_word_length=5
|
||||
)
|
||||
config = WordLadderConfig(min_word_length=3, max_word_length=5)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Verify word sets for each length
|
||||
for length in range(3, 6):
|
||||
assert length in dataset.word_sets
|
||||
|
|
@ -210,7 +198,7 @@ def test_word_ladder_csv_loading():
|
|||
assert len(word) == length
|
||||
assert word.isupper()
|
||||
assert word.isalpha()
|
||||
|
||||
|
||||
# Test invalid length range
|
||||
with pytest.raises(AssertionError):
|
||||
bad_config = WordLadderConfig(min_word_length=2, max_word_length=7)
|
||||
|
|
@ -219,23 +207,18 @@ def test_word_ladder_csv_loading():
|
|||
|
||||
def test_word_ladder_pair_generation():
|
||||
"""Test word pair generation logic"""
|
||||
config = WordLadderConfig(
|
||||
min_word_length=4,
|
||||
max_word_length=4,
|
||||
size=10,
|
||||
seed=42
|
||||
)
|
||||
config = WordLadderConfig(min_word_length=4, max_word_length=4, size=10, seed=42)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Test pair generation
|
||||
rng = Random(42)
|
||||
start, end, path = dataset._generate_word_pair(rng, 4)
|
||||
|
||||
|
||||
# Verify path properties
|
||||
assert start == path[0]
|
||||
assert end == path[-1]
|
||||
assert len(path) >= 3
|
||||
|
||||
|
||||
# Verify path is valid (each step differs by one letter)
|
||||
for i in range(len(path) - 1):
|
||||
current = path[i]
|
||||
|
|
@ -247,17 +230,17 @@ def test_word_graph_caching():
|
|||
"""Test word graph caching functionality"""
|
||||
config = WordLadderConfig(seed=42)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Verify initial graphs are precomputed
|
||||
assert len(dataset.word_graphs) > 0
|
||||
|
||||
|
||||
# Get initial graph for length 4
|
||||
graph_4 = dataset.word_graphs[4]
|
||||
|
||||
|
||||
# Verify cached graph is returned
|
||||
cached_graph = dataset._build_word_graph(4)
|
||||
assert cached_graph is dataset.word_graphs[4]
|
||||
|
||||
|
||||
# Verify graph structure
|
||||
for word, neighbors in graph_4.items():
|
||||
assert len(word) == 4
|
||||
|
|
@ -269,33 +252,24 @@ def test_word_graph_caching():
|
|||
|
||||
def test_word_ladder_path_validation():
|
||||
"""Test path length validation logic"""
|
||||
config = WordLadderConfig(
|
||||
min_chain_length=4,
|
||||
max_chain_length=6
|
||||
)
|
||||
|
||||
config = WordLadderConfig(min_chain_length=4, max_chain_length=6)
|
||||
|
||||
# Test specific length mode
|
||||
assert config.is_valid_path_length(4) # Min length
|
||||
assert config.is_valid_path_length(5) # Middle length
|
||||
assert config.is_valid_path_length(6) # Max length
|
||||
assert not config.is_valid_path_length(3) # Too short
|
||||
assert not config.is_valid_path_length(7) # Too long
|
||||
|
||||
|
||||
# Test shortest path mode
|
||||
config_shortest = WordLadderConfig(
|
||||
min_chain_length=-1,
|
||||
max_chain_length=-1
|
||||
)
|
||||
config_shortest = WordLadderConfig(min_chain_length=-1, max_chain_length=-1)
|
||||
assert config_shortest.is_valid_path_length(3)
|
||||
assert config_shortest.is_valid_path_length(4)
|
||||
assert config_shortest.is_valid_path_length(10)
|
||||
assert not config_shortest.is_valid_path_length(2)
|
||||
|
||||
|
||||
# Test mixed mode (shortest with max limit)
|
||||
config_mixed = WordLadderConfig(
|
||||
min_chain_length=-1,
|
||||
max_chain_length=5
|
||||
)
|
||||
config_mixed = WordLadderConfig(min_chain_length=-1, max_chain_length=5)
|
||||
assert config_mixed.is_valid_path_length(3)
|
||||
assert config_mixed.is_valid_path_length(4)
|
||||
assert config_mixed.is_valid_path_length(5)
|
||||
|
|
@ -305,44 +279,41 @@ def test_word_ladder_path_validation():
|
|||
def test_word_ladder_solution_optimality():
|
||||
"""Test that generated solutions are optimal when min_chain_length=-1"""
|
||||
config = WordLadderConfig(
|
||||
min_word_length=4,
|
||||
max_word_length=4,
|
||||
min_chain_length=-1,
|
||||
max_chain_length=-1,
|
||||
size=20,
|
||||
seed=42
|
||||
min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=20, seed=42
|
||||
)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
solution_chain = item["answer"].split(",")
|
||||
start_word = item["metadata"]["start_word"]
|
||||
end_word = item["metadata"]["end_word"]
|
||||
|
||||
|
||||
# Verify this is the shortest possible path
|
||||
word_set = dataset.word_sets[len(start_word)]
|
||||
|
||||
|
||||
# Build graph and use BFS to find shortest path
|
||||
from collections import deque
|
||||
|
||||
queue = deque([(start_word, [start_word])])
|
||||
visited = {start_word}
|
||||
shortest_path = None
|
||||
|
||||
|
||||
while queue and not shortest_path:
|
||||
current_word, path = queue.popleft()
|
||||
if current_word == end_word:
|
||||
shortest_path = path
|
||||
break
|
||||
|
||||
|
||||
for neighbor in dataset._get_neighbors(current_word, word_set):
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor]))
|
||||
|
||||
|
||||
assert shortest_path is not None, f"No path found between {start_word} and {end_word}"
|
||||
assert len(solution_chain) == len(shortest_path), \
|
||||
f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"
|
||||
assert len(solution_chain) == len(
|
||||
shortest_path
|
||||
), f"Solution {solution_chain} is not optimal. Shortest path: {shortest_path}"
|
||||
|
||||
|
||||
def test_word_ladder_performance():
|
||||
|
|
@ -351,13 +322,13 @@ def test_word_ladder_performance():
|
|||
start_time = time.time()
|
||||
dataset = WordLadderDataset(config)
|
||||
init_time = time.time() - start_time
|
||||
|
||||
|
||||
# Test item generation time
|
||||
start_time = time.time()
|
||||
for i in range(len(dataset)):
|
||||
_ = dataset[i]
|
||||
access_time = time.time() - start_time
|
||||
|
||||
|
||||
# These thresholds should be adjusted based on requirements
|
||||
assert init_time < 2.0, f"Initialization took too long: {init_time:.2f}s"
|
||||
assert access_time < 1.0, f"Data access took too long: {access_time:.2f}s"
|
||||
|
|
@ -369,24 +340,18 @@ def test_word_ladder_edge_cases():
|
|||
config = WordLadderConfig(size=1)
|
||||
dataset = WordLadderDataset(config)
|
||||
assert len(dataset) == 1
|
||||
|
||||
|
||||
# Test with same start/end word length but maximum distance
|
||||
config = WordLadderConfig(
|
||||
min_word_length=4,
|
||||
max_word_length=4,
|
||||
min_chain_length=-1,
|
||||
max_chain_length=-1,
|
||||
size=10
|
||||
)
|
||||
config = WordLadderConfig(min_word_length=4, max_word_length=4, min_chain_length=-1, max_chain_length=-1, size=10)
|
||||
dataset = WordLadderDataset(config)
|
||||
|
||||
|
||||
# Find the pair with longest solution
|
||||
max_length = 0
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
chain_length = len(item["answer"].split(","))
|
||||
max_length = max(max_length, chain_length)
|
||||
|
||||
|
||||
assert max_length > 3, "No challenging word pairs generated"
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue