def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
    """Pick two distinct words of ``length`` letters together with a ladder path between them.

    Candidate pairs are drawn repeatedly (at most 500 attempts) until
    ``self._find_path`` returns a non-empty path that satisfies the configured
    chain-length bounds. For five-letter words the draw is weighted: a word
    ending in ``'S'`` gets weight 0.5, every other word weight 1.0. For all
    other lengths the pair is drawn uniformly with ``rng.sample``.

    Args:
        rng: Random number generator that drives every draw.
        length: Word length; must be a key of ``self.word_sets``.

    Returns:
        A ``(start, end, path)`` tuple, where ``path`` is whatever
        ``self._find_path(start, end, word_set)`` returned.

    Raises:
        KeyError: If ``length`` is not a key of ``self.word_sets``.
        ValueError: If fewer than two words of that length are available.
        RuntimeError: If no acceptable pair is found within the attempt budget.
    """
    word_set = self.word_sets[length]
    max_attempts = 500

    # Prefer the per-length sorted list cached on the instance as
    # ``self.words_lists``; if that cache is missing, sort once per call
    # (never once per attempt). Sorting keeps draws reproducible per seed.
    cached_lists = getattr(self, "words_lists", None) or {}
    words_list = cached_lists.get(length)
    if words_list is None:
        words_list = sorted(word_set)

    if len(words_list) < 2:
        # Fail with a clear message instead of an opaque error raised from
        # inside rng.sample / rng.choices.
        raise ValueError(f"Need at least two words of length {length} to form a pair")

    # Weighted sampling is applied to five-letter words only. The weights do
    # not depend on the attempt, so build them once outside the retry loop.
    weights = None
    if length == 5:
        weights = [0.5 if word.endswith("S") else 1.0 for word in words_list]
    positions = range(len(words_list))

    unbounded = self.config.min_chain_length == -1 and self.config.max_chain_length == -1

    for _ in range(max_attempts):
        if weights is None:
            start, end = rng.sample(words_list, 2)
        else:
            # Draw the first word by position so the second draw can exclude
            # it by slicing both the words and the precomputed weights,
            # guaranteeing start != end without re-deriving any weight.
            pos = rng.choices(positions, weights=weights, k=1)[0]
            start = words_list[pos]
            end = rng.choices(
                words_list[:pos] + words_list[pos + 1 :],
                weights=weights[:pos] + weights[pos + 1 :],
                k=1,
            )[0]

        path = self._find_path(start, end, word_set)
        if path and (
            unbounded
            or self.config.min_chain_length <= len(path) <= self.config.max_chain_length
        ):
            return start, end, path

    raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")