mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Refactor word ladder generation with improved validation and graph-based path finding
- Enhanced configuration validation with size and length constraints - Implemented graph-based neighbor computation and caching - Simplified path finding algorithm with more robust length checking - Added more flexible word set loading with configurable length ranges - Improved error handling for dataset generation
This commit is contained in:
parent
7b61fc5043
commit
d5065955a8
1 changed files with 116 additions and 120 deletions
|
|
@ -12,18 +12,22 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
class WordLadderConfig:
|
||||
"""Configuration for word ladder task generation"""
|
||||
|
||||
min_word_length: int = 3 # Minimum word length
|
||||
max_word_length: int = 5 # Maximum word length
|
||||
min_word_length: int = 4 # Minimum word length
|
||||
max_word_length: int = 4 # Maximum word length
|
||||
min_chain_length: int = -1 # Set to -1 for shortest path or a minimum of 3
|
||||
max_chain_length: int = -1 # Set to -1 for shortest path or a max
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_word_length > 2, "min_word_length must be 3"
|
||||
assert self.min_word_length >= 3, "min_word_length must be >= 3"
|
||||
assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length"
|
||||
assert self.max_word_length <= 5, "max_word_length must be 5"
|
||||
assert self.max_word_length <= 5, "max_word_length must be <= 5"
|
||||
|
||||
# Add size validation
|
||||
if self.size > 20000: # Add reasonable upper limit
|
||||
raise ValueError("Dataset size too large for this algorithm and constraints")
|
||||
|
||||
# Modified validation logic
|
||||
if self.min_chain_length == -1:
|
||||
|
|
@ -35,22 +39,46 @@ class WordLadderConfig:
|
|||
assert self.min_chain_length >= 3, "min_chain_length must be 3 or -1"
|
||||
assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
|
||||
|
||||
def is_valid_path_length(self, length: int) -> bool:
|
||||
"""Check if a path length meets the configuration requirements"""
|
||||
# When min_chain_length is -1, we accept any path of length >= 3
|
||||
if self.min_chain_length == -1:
|
||||
if self.max_chain_length == -1:
|
||||
return length >= 3
|
||||
return 3 <= length <= self.max_chain_length
|
||||
|
||||
# Otherwise check against both min and max
|
||||
return (self.min_chain_length <= length <=
|
||||
(self.max_chain_length if self.max_chain_length != -1 else float('inf')))
|
||||
|
||||
class WordLadderDataset(ProceduralDataset):
|
||||
"""Generates word ladder transformation tasks"""
|
||||
|
||||
def __init__(self, config: WordLadderConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.config = config
|
||||
self.word_sets = {}
|
||||
self.word_graphs = {}
|
||||
|
||||
# Load words from CSV file
|
||||
self.word_sets = self._load_words_from_csv()
|
||||
# Precompute the sorted word lists for each word length
|
||||
self.words_lists = {
|
||||
length: sorted(words)
|
||||
for length, words in self.word_sets.items()
|
||||
}
|
||||
# Load words from CSV
|
||||
self.word_sets = self._load_words_from_csv(
|
||||
min_length=self.config.min_word_length,
|
||||
max_length=self.config.max_word_length
|
||||
)
|
||||
|
||||
# Precompute word graphs for all lengths
|
||||
for length in range(self.config.min_word_length, self.config.max_word_length + 1):
|
||||
self.word_graphs[length] = self._build_word_graph(length)
|
||||
|
||||
config.validate()
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _load_words_from_csv(self) -> Dict[int, Set[str]]:
|
||||
|
||||
@classmethod
|
||||
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]:
|
||||
"""Load words from CSV file organized by length"""
|
||||
# Validate length range before processing
|
||||
assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive"
|
||||
|
||||
import csv
|
||||
from io import StringIO
|
||||
word_sets = {}
|
||||
|
|
@ -64,153 +92,122 @@ class WordLadderDataset(ProceduralDataset):
|
|||
reader = csv.DictReader(csv_file)
|
||||
|
||||
for row in reader:
|
||||
# Process each word length column
|
||||
for length in range(3, 6):
|
||||
# Process each word length column using config range
|
||||
for length in range(min_length, max_length + 1):
|
||||
col_name = f'{length}_letter'
|
||||
word = row.get(col_name, '')
|
||||
|
||||
if not word: # Skip empty entries
|
||||
continue
|
||||
|
||||
if self.config.min_word_length <= length <= self.config.max_word_length:
|
||||
word_sets.setdefault(length, set()).add(word.upper())
|
||||
word_sets.setdefault(length, set()).add(word.upper())
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing words.csv content: {e}") from e
|
||||
|
||||
# Validate we have words for each length
|
||||
for length in range(self.config.min_word_length, self.config.max_word_length + 1):
|
||||
for length in range(min_length, max_length + 1):
|
||||
if length not in word_sets or not word_sets[length]:
|
||||
raise ValueError(f"No valid words found for length {length}")
|
||||
|
||||
return word_sets
|
||||
|
||||
def _differs_by_one(self, word1: str, word2: str) -> bool:
|
||||
"""Check if two words differ by exactly one letter"""
|
||||
if len(word1) != len(word2):
|
||||
return False
|
||||
differences = 0
|
||||
for c1, c2 in zip(word1, word2):
|
||||
if c1 != c2:
|
||||
differences += 1
|
||||
if differences > 1:
|
||||
return False
|
||||
return differences == 1
|
||||
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
|
||||
"""Get neighbors from either precomputed graph or by computing on demand"""
|
||||
# Try precomputed graph first
|
||||
if len(word) in self.word_graphs and word in self.word_graphs[len(word)]:
|
||||
return self.word_graphs[len(word)].get(word, set())
|
||||
|
||||
# Fall back to computing neighbors directly for custom word sets
|
||||
neighbors = set()
|
||||
for i in range(len(word)):
|
||||
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
||||
neighbor = word[:i] + c + word[i+1:]
|
||||
if neighbor != word and neighbor in word_set:
|
||||
neighbors.add(neighbor)
|
||||
return neighbors
|
||||
|
||||
def _build_word_graph(self, word_length: int) -> Dict[str, Set[str]]:
|
||||
"""Build graph of word connections for given length, using caching"""
|
||||
# Return cached graph if it exists
|
||||
if word_length in self.word_graphs:
|
||||
return self.word_graphs[word_length]
|
||||
|
||||
# Build new graph
|
||||
word_set = self.word_sets[word_length]
|
||||
graph = {}
|
||||
|
||||
# Build connections
|
||||
for word in word_set:
|
||||
neighbors = set()
|
||||
for i in range(word_length):
|
||||
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
||||
neighbor = word[:i] + c + word[i+1:]
|
||||
if neighbor != word and neighbor in word_set:
|
||||
neighbors.add(neighbor)
|
||||
graph[word] = neighbors
|
||||
|
||||
# Cache and return
|
||||
self.word_graphs[word_length] = graph
|
||||
return self.word_graphs[word_length]
|
||||
|
||||
def _find_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
|
||||
"""Find path between start and end words that meets length requirements"""
|
||||
if start == end:
|
||||
return [start]
|
||||
|
||||
# First find shortest path length
|
||||
shortest_path = self._bfs_shortest_path(start, end, word_set)
|
||||
if not shortest_path:
|
||||
return None
|
||||
"""Simplified path finding using BFS for shortest paths"""
|
||||
# Early exit if words are direct neighbors
|
||||
if end in self._get_neighbors(start, word_set):
|
||||
return [start, end]
|
||||
|
||||
min_length = self.config.min_chain_length
|
||||
if len(shortest_path) > min_length:
|
||||
return shortest_path # Shortest path is already longer than required
|
||||
|
||||
# Now look for longer paths using DFS with depth constraint
|
||||
return self._dfs_with_depth(start, end, word_set, min_length)
|
||||
|
||||
def _bfs_shortest_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
|
||||
"""BFS implementation to find shortest path"""
|
||||
# Use basic BFS for shortest path
|
||||
queue = deque([(start, [start])])
|
||||
visited = {start}
|
||||
|
||||
while queue:
|
||||
current, path = queue.popleft()
|
||||
if current == end:
|
||||
return path
|
||||
if self.config.is_valid_path_length(len(path)):
|
||||
return path
|
||||
return None
|
||||
|
||||
for neighbor in self._get_neighbors(current, word_set):
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor]))
|
||||
return None
|
||||
|
||||
def _dfs_with_depth(self, start: str, end: str, word_set: Set[str], target_length: int) -> Optional[List[str]]:
|
||||
"""DFS implementation looking for paths of exact length"""
|
||||
stack = [(start, [start], set([start]))]
|
||||
new_path = path + [neighbor]
|
||||
queue.append((neighbor, new_path))
|
||||
|
||||
while stack:
|
||||
current, path, visited = stack.pop()
|
||||
|
||||
if len(path) == target_length:
|
||||
if current == end:
|
||||
return path
|
||||
continue
|
||||
|
||||
if len(path) > target_length:
|
||||
continue
|
||||
|
||||
# Explore neighbors in random order to find different paths
|
||||
neighbors = list(self._get_neighbors(current, word_set))
|
||||
Random().shuffle(neighbors)
|
||||
|
||||
for neighbor in neighbors:
|
||||
if neighbor not in visited:
|
||||
new_visited = set(visited)
|
||||
new_visited.add(neighbor)
|
||||
stack.append((neighbor, path + [neighbor], new_visited))
|
||||
|
||||
return None
|
||||
|
||||
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
|
||||
"""Get all valid neighbors that differ by one letter"""
|
||||
neighbors = set()
|
||||
word_chars = list(word)
|
||||
|
||||
for i in range(len(word_chars)):
|
||||
original = word_chars[i]
|
||||
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
||||
if c == original:
|
||||
continue
|
||||
word_chars[i] = c
|
||||
new_word = ''.join(word_chars)
|
||||
if new_word in word_set:
|
||||
neighbors.add(new_word)
|
||||
word_chars[i] = original
|
||||
|
||||
return neighbors
|
||||
|
||||
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
|
||||
"""Generate valid start/end words with solution path, with lower weight for 5-letter words ending with 'S'"""
|
||||
"""Simplified word pair generation"""
|
||||
word_set = self.word_sets[length]
|
||||
max_attempts = 500
|
||||
words_list = sorted(word_set)
|
||||
max_attempts = 100
|
||||
|
||||
words_list = self.words_lists[length]
|
||||
|
||||
# Use weighted sampling only for five-letter words
|
||||
use_weights = (length == 5)
|
||||
for _ in range(max_attempts):
|
||||
if use_weights:
|
||||
# Compute weights: assign 0.5 weight if a five-letter word ends with 'S', else 1.0
|
||||
weights = [0.5 if word.endswith('S') else 1.0 for word in words_list]
|
||||
start = rng.choices(words_list, weights=weights, k=1)[0]
|
||||
# Remove chosen word to ensure distinct selection for the second word
|
||||
remaining_words = words_list.copy()
|
||||
remaining_words.remove(start)
|
||||
weights_second = [0.5 if word.endswith('S') else 1.0 for word in remaining_words]
|
||||
end = rng.choices(remaining_words, weights=weights_second, k=1)[0]
|
||||
else:
|
||||
start, end = rng.sample(words_list, 2)
|
||||
start = rng.choice(words_list)
|
||||
end = rng.choice(words_list)
|
||||
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
path = self._find_path(start, end, word_set)
|
||||
if path and (
|
||||
(self.config.min_chain_length == -1 and self.config.max_chain_length == -1) or
|
||||
(self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
|
||||
):
|
||||
if path:
|
||||
return start, end, path
|
||||
|
||||
raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")
|
||||
|
||||
raise RuntimeError(f"Failed to find valid pair for length {length}")
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single word ladder task"""
|
||||
rng = Random(self.seed + idx)
|
||||
length = rng.randint(self.config.min_word_length, self.config.max_word_length)
|
||||
start, end, path = self._generate_word_pair(rng, length)
|
||||
if idx >= self.size:
|
||||
raise IndexError(f"Dataset index {idx} out of range for size {self.size}")
|
||||
|
||||
try:
|
||||
rng = Random(self.seed + idx)
|
||||
length = rng.randint(self.config.min_word_length, self.config.max_word_length)
|
||||
start, end, path = self._generate_word_pair(rng, length)
|
||||
except RuntimeError as e:
|
||||
# If we run out of valid paths, adjust the virtual size
|
||||
self.size = idx
|
||||
raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}")
|
||||
|
||||
return {
|
||||
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.",
|
||||
|
|
@ -223,5 +220,4 @@ class WordLadderDataset(ProceduralDataset):
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
|
||||
Loading…
Add table
Add a link
Reference in a new issue