diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 77c43289..d0342cf7 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -12,18 +12,22 @@ from ..factory import ProceduralDataset, register_dataset class WordLadderConfig: """Configuration for word ladder task generation""" - min_word_length: int = 3 # Minimum word length - max_word_length: int = 5 # Maximum word length + min_word_length: int = 4 # Minimum word length + max_word_length: int = 4 # Maximum word length min_chain_length: int = -1 # Set to -1 for shortest path or a minimum of 3 max_chain_length: int = -1 # Set to -1 for shortest path or a max seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 def validate(self) -> None: """Validate configuration parameters""" - assert self.min_word_length > 2, "min_word_length must be 3" + assert self.min_word_length >= 3, "min_word_length must be >= 3" assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length" - assert self.max_word_length <= 5, "max_word_length must be 5" + assert self.max_word_length <= 5, "max_word_length must be <= 5" + + # Add size validation + if self.size > 20000: # Add reasonable upper limit + raise ValueError("Dataset size too large for this algorithm and constraints") # Modified validation logic if self.min_chain_length == -1: @@ -35,22 +39,46 @@ class WordLadderConfig: assert self.min_chain_length >= 3, "min_chain_length must be 3 or -1" assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length" + def is_valid_path_length(self, length: int) -> bool: + """Check if a path length meets the configuration requirements""" + # When min_chain_length is -1, we accept any path of length >= 3 + if self.min_chain_length == -1: + if self.max_chain_length == -1: + return length >= 3 + return 3 <= length <= self.max_chain_length + + # Otherwise check against both min and max + return (self.min_chain_length <= length <= + (self.max_chain_length if self.max_chain_length != -1 else float('inf'))) + class WordLadderDataset(ProceduralDataset): """Generates word ladder transformation tasks""" def __init__(self, config: WordLadderConfig): - super().__init__(config=config, seed=config.seed, size=config.size) + self.config = config + self.word_sets = {} + self.word_graphs = {} - # Load words from CSV file - self.word_sets = self._load_words_from_csv() - # Precompute the sorted word lists for each word length - self.words_lists = { - length: sorted(words) - for length, words in self.word_sets.items() - } + # Load words from CSV + self.word_sets = self._load_words_from_csv( + min_length=self.config.min_word_length, + max_length=self.config.max_word_length + ) + + # Precompute word graphs for all lengths + for length in range(self.config.min_word_length, self.config.max_word_length + 1): + self.word_graphs[length] = self._build_word_graph(length) + + config.validate() + super().__init__(config=config, seed=config.seed, size=config.size) - def _load_words_from_csv(self) -> Dict[int, Set[str]]: + + @classmethod + def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]: """Load words from CSV file organized by length""" + # Validate length range before processing + assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive" + import csv from io import StringIO word_sets = {} @@ -64,153 +92,122 @@ class WordLadderDataset(ProceduralDataset): reader = csv.DictReader(csv_file) for row in reader: - # Process each word length column - for length in range(3, 6): + # Process each word length column using config range + for length in range(min_length, max_length + 1): col_name = f'{length}_letter' word = row.get(col_name, '') if not word: # Skip empty entries continue - if self.config.min_word_length <= length <= self.config.max_word_length: - word_sets.setdefault(length, set()).add(word.upper()) + word_sets.setdefault(length, set()).add(word.upper()) except Exception as e: raise RuntimeError(f"Error processing words.csv content: {e}") from e # Validate we have words for each length - for length in range(self.config.min_word_length, self.config.max_word_length + 1): + for length in range(min_length, max_length + 1): if length not in word_sets or not word_sets[length]: raise ValueError(f"No valid words found for length {length}") return word_sets - def _differs_by_one(self, word1: str, word2: str) -> bool: - """Check if two words differ by exactly one letter""" - if len(word1) != len(word2): - return False - differences = 0 - for c1, c2 in zip(word1, word2): - if c1 != c2: - differences += 1 - if differences > 1: - return False - return differences == 1 + def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]: + """Get neighbors from either precomputed graph or by computing on demand""" + # Try precomputed graph first + if len(word) in self.word_graphs and word in self.word_graphs[len(word)]: + return self.word_graphs[len(word)].get(word, set()) + + # Fall back to computing neighbors directly for custom word sets + neighbors = set() + for i in range(len(word)): + for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ': + neighbor = word[:i] + c + word[i+1:] + if neighbor != word and neighbor in word_set: + neighbors.add(neighbor) + return neighbors + + def _build_word_graph(self, word_length: int) -> Dict[str, Set[str]]: + """Build graph of word connections for given length, using caching""" + # Return cached graph if it exists + if word_length in self.word_graphs: + return self.word_graphs[word_length] + + # Build new graph + word_set = self.word_sets[word_length] + graph = {} + + # Build connections + for word in word_set: + neighbors = set() + for i in range(word_length): + for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ': + neighbor = word[:i] + c + word[i+1:] + if neighbor != word and neighbor in word_set: + neighbors.add(neighbor) + graph[word] = neighbors + + # Cache and return + self.word_graphs[word_length] = graph + return self.word_graphs[word_length] def _find_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]: - """Find path between start and end words that meets length requirements""" - if start == end: - return [start] - - # First find shortest path length - shortest_path = self._bfs_shortest_path(start, end, word_set) - if not shortest_path: - return None + """Simplified path finding using BFS for shortest paths""" + # Early exit if words are direct neighbors + if end in self._get_neighbors(start, word_set): + return [start, end] - min_length = self.config.min_chain_length - if len(shortest_path) > min_length: - return shortest_path # Shortest path is already longer than required - - # Now look for longer paths using DFS with depth constraint - return self._dfs_with_depth(start, end, word_set, min_length) - - def _bfs_shortest_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]: - """BFS implementation to find shortest path""" + # Use basic BFS for shortest path queue = deque([(start, [start])]) visited = {start} while queue: current, path = queue.popleft() if current == end: - return path + if self.config.is_valid_path_length(len(path)): + return path + return None for neighbor in self._get_neighbors(current, word_set): if neighbor not in visited: visited.add(neighbor) - queue.append((neighbor, path + [neighbor])) - return None - - def _dfs_with_depth(self, start: str, end: str, word_set: Set[str], target_length: int) -> Optional[List[str]]: - """DFS implementation looking for paths of exact length""" - stack = [(start, [start], set([start]))] + new_path = path + [neighbor] + queue.append((neighbor, new_path)) - while stack: - current, path, visited = stack.pop() - - if len(path) == target_length: - if current == end: - return path - continue - - if len(path) > target_length: - continue - - # Explore neighbors in random order to find different paths - neighbors = list(self._get_neighbors(current, word_set)) - Random().shuffle(neighbors) - - for neighbor in neighbors: - if neighbor not in visited: - new_visited = set(visited) - new_visited.add(neighbor) - stack.append((neighbor, path + [neighbor], new_visited)) - return None - def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]: - """Get all valid neighbors that differ by one letter""" - neighbors = set() - word_chars = list(word) - - for i in range(len(word_chars)): - original = word_chars[i] - for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ': - if c == original: - continue - word_chars[i] = c - new_word = ''.join(word_chars) - if new_word in word_set: - neighbors.add(new_word) - word_chars[i] = original - - return neighbors - def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]: - """Generate valid start/end words with solution path, with lower weight for 5-letter words ending with 'S'""" + """Simplified word pair generation""" word_set = self.word_sets[length] - max_attempts = 500 + words_list = sorted(word_set) + max_attempts = 100 - words_list = self.words_lists[length] - - # Use weighted sampling only for five-letter words - use_weights = (length == 5) for _ in range(max_attempts): - if use_weights: - # Compute weights: assign 0.5 weight if a five-letter word ends with 'S', else 1.0 - weights = [0.5 if word.endswith('S') else 1.0 for word in words_list] - start = rng.choices(words_list, weights=weights, k=1)[0] - # Remove chosen word to ensure distinct selection for the second word - remaining_words = words_list.copy() - remaining_words.remove(start) - weights_second = [0.5 if word.endswith('S') else 1.0 for word in remaining_words] - end = rng.choices(remaining_words, weights=weights_second, k=1)[0] - else: - start, end = rng.sample(words_list, 2) + start = rng.choice(words_list) + end = rng.choice(words_list) + if start == end: + continue + path = self._find_path(start, end, word_set) - if path and ( - (self.config.min_chain_length == -1 and self.config.max_chain_length == -1) or - (self.config.min_chain_length <= len(path) <= self.config.max_chain_length) - ): + if path: return start, end, path - - raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts") + + raise RuntimeError(f"Failed to find valid pair for length {length}") def __getitem__(self, idx: int) -> dict: """Generate a single word ladder task""" - rng = Random(self.seed + idx) - length = rng.randint(self.config.min_word_length, self.config.max_word_length) - start, end, path = self._generate_word_pair(rng, length) + if idx >= self.size: + raise IndexError(f"Dataset index {idx} out of range for size {self.size}") + + try: + rng = Random(self.seed + idx) + length = rng.randint(self.config.min_word_length, self.config.max_word_length) + start, end, path = self._generate_word_pair(rng, length) + except RuntimeError as e: + # If we run out of valid paths, adjust the virtual size + self.size = idx + raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}") return { "question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.", @@ -223,5 +220,4 @@ class WordLadderDataset(ProceduralDataset): } } - register_dataset("word_ladder", WordLadderDataset, WordLadderConfig) \ No newline at end of file