mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Improve efficiency and reduce plural bias in word ladder generation
- Precomputed sorted word lists for each word length (stored in self.words_lists) to avoid redundant sorting on every _generate_word_pair call. - Updated _generate_word_pair to utilize the cached sorted list, significantly improving computational efficiency. - Implemented weighted random sampling for 5-letter words, giving words ending with 'S' a lower weight (0.5) to reduce bias without completely filtering them out.
This commit is contained in:
parent
fce0c4fa3f
commit
511425797f
1 changed files with 23 additions and 3 deletions
|
|
@ -43,6 +43,11 @@ class WordLadderDataset(ProceduralDataset):
|
|||
|
||||
# Load words from CSV file
|
||||
self.word_sets = self._load_words_from_csv()
|
||||
# Precompute the sorted word lists for each word length
|
||||
self.words_lists = {
|
||||
length: sorted(words)
|
||||
for length, words in self.word_sets.items()
|
||||
}
|
||||
|
||||
def _load_words_from_csv(self) -> Dict[int, Set[str]]:
|
||||
"""Load words from CSV file organized by length"""
|
||||
|
|
@ -171,19 +176,34 @@ class WordLadderDataset(ProceduralDataset):
|
|||
return neighbors
|
||||
|
||||
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
|
||||
"""Generate valid start/end words with solution path"""
|
||||
"""Generate valid start/end words with solution path, with lower weight for 5-letter words ending with 'S'"""
|
||||
word_set = self.word_sets[length]
|
||||
max_attempts = 500
|
||||
|
||||
words_list = self.words_lists[length]
|
||||
|
||||
# Use weighted sampling only for five-letter words
|
||||
use_weights = (length == 5)
|
||||
for _ in range(max_attempts):
|
||||
start, end = rng.sample(sorted(word_set), 2)
|
||||
if use_weights:
|
||||
# Compute weights: assign 0.5 weight if a five-letter word ends with 'S', else 1.0
|
||||
weights = [0.5 if word.endswith('S') else 1.0 for word in words_list]
|
||||
start = rng.choices(words_list, weights=weights, k=1)[0]
|
||||
# Remove chosen word to ensure distinct selection for the second word
|
||||
remaining_words = words_list.copy()
|
||||
remaining_words.remove(start)
|
||||
weights_second = [0.5 if word.endswith('S') else 1.0 for word in remaining_words]
|
||||
end = rng.choices(remaining_words, weights=weights_second, k=1)[0]
|
||||
else:
|
||||
start, end = rng.sample(words_list, 2)
|
||||
|
||||
path = self._find_path(start, end, word_set)
|
||||
if path and (
|
||||
(self.config.min_chain_length == -1 and self.config.max_chain_length == -1) or
|
||||
(self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
|
||||
):
|
||||
return start, end, path
|
||||
|
||||
|
||||
raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue