Improve efficiency and reduce plural bias in word ladder generation

- Precomputed sorted word lists for each word length (stored in self.words_lists) to avoid redundant sorting on every _generate_word_pair call.
- Updated _generate_word_pair to utilize the cached sorted list, significantly improving computational efficiency.
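- Implemented weighted random sampling for 5-letter words: words ending with 'S' get a weight of 0.5 (versus 1.0 for all other words), which reduces plural bias without filtering those words out completely. The second word is drawn with the same weighting from the list with the first word removed, so the pair is always distinct.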
This commit is contained in:
Cavit Erginsoy 2025-02-01 14:37:21 +00:00
parent fce0c4fa3f
commit 511425797f

View file

@ -43,6 +43,11 @@ class WordLadderDataset(ProceduralDataset):
# Load words from CSV file
self.word_sets = self._load_words_from_csv()
# Precompute the sorted word lists for each word length
self.words_lists = {
length: sorted(words)
for length, words in self.word_sets.items()
}
def _load_words_from_csv(self) -> Dict[int, Set[str]]:
"""Load words from CSV file organized by length"""
@ -171,19 +176,34 @@ class WordLadderDataset(ProceduralDataset):
return neighbors
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
    """Pick two distinct words of ``length`` letters that are joined by a path.

    Candidates come from the pre-sorted list in ``self.words_lists[length]``.
    For five-letter words the draw is weighted: words ending with 'S' get
    weight 0.5 and every other word gets weight 1.0, so they are chosen less
    often but are never excluded. All other lengths use uniform sampling
    without replacement.

    Args:
        rng: Random number generator used for every draw.
        length: Word length; must be a key of ``self.word_sets`` and
            ``self.words_lists``.

    Returns:
        A ``(start, end, path)`` tuple where ``path`` is whatever
        ``self._find_path(start, end, word_set)`` returned for the pair.

    Raises:
        ValueError: If fewer than two words of the requested length exist.
        RuntimeError: If no acceptable pair is found within 500 attempts.
    """
    word_set = self.word_sets[length]
    words_list = self.words_lists[length]
    if len(words_list) < 2:
        # Fail with one clear error instead of an IndexError from
        # rng.choices or a generic ValueError from rng.sample.
        raise ValueError(f"Need at least two words of length {length} to generate a pair")

    max_attempts = 500

    # Weighted sampling is used only for five-letter words. The weights do
    # not depend on the attempt, so they are built once outside the loop.
    weights = None
    if length == 5:
        weights = [0.5 if word.endswith('S') else 1.0 for word in words_list]

    for _ in range(max_attempts):
        if weights is not None:
            start = rng.choices(words_list, weights=weights, k=1)[0]
            # Drop the chosen word (and its weight) so the second draw is
            # guaranteed to produce a different word.
            chosen = words_list.index(start)
            remaining_words = words_list[:chosen] + words_list[chosen + 1:]
            remaining_weights = weights[:chosen] + weights[chosen + 1:]
            end = rng.choices(remaining_words, weights=remaining_weights, k=1)[0]
        else:
            start, end = rng.sample(words_list, 2)

        path = self._find_path(start, end, word_set)
        # A min/max of -1 on both bounds disables the chain-length check.
        if path and (
            (self.config.min_chain_length == -1 and self.config.max_chain_length == -1)
            or (self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
        ):
            return start, end, path

    raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")
def __getitem__(self, idx: int) -> dict: