This commit is contained in:
Andreas Koepf 2025-01-30 23:14:32 +01:00
parent 048a165314
commit 5ae329becd
6 changed files with 148 additions and 124 deletions

View file

@ -14,9 +14,9 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
__all__ = [
"SpellBackwardConfig",

View file

@ -1,46 +1,51 @@
"""Word ladder task generator"""
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Set, Dict, Tuple
from collections import deque
from typing import Dict, List, Optional, Set, Tuple
from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@dataclass
class WordLadderConfig:
    """Configuration for word ladder task generation."""

    min_word_length: int = 3  # Minimum word length (floor enforced at 3)
    max_word_length: int = 5  # Maximum word length (cap enforced at 5)
    min_chain_length: int = -1  # Set to -1 for shortest path or a minimum of 3
    max_chain_length: int = -1  # Set to -1 for shortest path or an explicit cap
    seed: Optional[int] = None  # Base RNG seed for reproducible items
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if any parameter or parameter combination is invalid.
        """
        assert self.min_word_length > 2, "min_word_length must be 3"
        assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length"
        assert self.max_word_length <= 5, "max_word_length must be 5"

        if self.min_chain_length == -1:
            # Shortest-path mode: max_chain_length may stay -1 (unbounded)
            # or provide an explicit cap of at least 3.
            if self.max_chain_length != -1:
                assert (
                    self.max_chain_length >= 3
                ), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
        elif self.max_chain_length == -1:
            # An explicit minimum with an unbounded maximum is not supported.
            raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
        else:
            assert self.min_chain_length >= 3, "min_chain_length must be 3 or -1"
            assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
class WordLadderDataset(ProceduralDataset):
"""Generates word ladder transformation tasks"""
def __init__(self, config: WordLadderConfig):
    """Initialize the dataset from *config* and preload word sets by length."""
    super().__init__(config=config, seed=config.seed, size=config.size)
    # Load words from CSV file
    self.word_sets = self._load_words_from_csv()
@ -48,36 +53,37 @@ class WordLadderDataset(ProceduralDataset):
"""Load words from CSV file organized by length"""
import csv
from io import StringIO
word_sets = {}
try:
# Get CSV content as string
csv_content = read_data_file("words.csv")
# Use StringIO to create a file-like object from the string
csv_file = StringIO(csv_content)
reader = csv.DictReader(csv_file)
for row in reader:
# Process each word length column
for length in range(3, 6):
col_name = f'{length}_letter'
word = row.get(col_name, '')
col_name = f"{length}_letter"
word = row.get(col_name, "")
if not word: # Skip empty entries
continue
if self.config.min_word_length <= length <= self.config.max_word_length:
word_sets.setdefault(length, set()).add(word.upper())
except Exception as e:
raise RuntimeError(f"Error processing words.csv content: {e}") from e
# Validate we have words for each length
for length in range(self.config.min_word_length, self.config.max_word_length + 1):
if length not in word_sets or not word_sets[length]:
raise ValueError(f"No valid words found for length {length}")
return word_sets
def _differs_by_one(self, word1: str, word2: str) -> bool:
@ -96,16 +102,16 @@ class WordLadderDataset(ProceduralDataset):
"""Find path between start and end words that meets length requirements"""
if start == end:
return [start]
# First find shortest path length
shortest_path = self._bfs_shortest_path(start, end, word_set)
if not shortest_path:
return None
min_length = self.config.min_chain_length
if len(shortest_path) > min_length:
return shortest_path # Shortest path is already longer than required
# Now look for longer paths using DFS with depth constraint
return self._dfs_with_depth(start, end, word_set, min_length)
@ -113,12 +119,12 @@ class WordLadderDataset(ProceduralDataset):
"""BFS implementation to find shortest path"""
queue = deque([(start, [start])])
visited = {start}
while queue:
current, path = queue.popleft()
if current == end:
return path
for neighbor in self._get_neighbors(current, word_set):
if neighbor not in visited:
visited.add(neighbor)
@ -128,62 +134,62 @@ class WordLadderDataset(ProceduralDataset):
def _dfs_with_depth(
    self,
    start: str,
    end: str,
    word_set: Set[str],
    target_length: int,
    rng: Optional[Random] = None,
) -> Optional[List[str]]:
    """DFS looking for a path of exactly *target_length* words.

    Args:
        start: Starting word (assumed to be in *word_set*).
        end: Target word.
        word_set: Valid words of the same length as *start*/*end*.
        target_length: Required path length, counting both endpoints.
        rng: Optional random source used to shuffle neighbor exploration
            order. Pass the dataset's per-item seeded Random for
            reproducible output; defaults to an unseeded Random, which
            matches the original (non-deterministic) behavior.

    Returns:
        A path [start, ..., end] of exactly *target_length* words, or
        None if no such path exists.
    """
    # NOTE(review): the original always used an unseeded Random() here, so
    # generated ladders were not reproducible even with a fixed config seed.
    # The optional `rng` parameter lets callers thread the seeded RNG through.
    if rng is None:
        rng = Random()
    stack = [(start, [start], set([start]))]
    while stack:
        current, path, visited = stack.pop()
        if len(path) == target_length:
            if current == end:
                return path
            continue
        if len(path) > target_length:
            continue
        # Explore neighbors in random order to find different paths
        neighbors = list(self._get_neighbors(current, word_set))
        rng.shuffle(neighbors)
        for neighbor in neighbors:
            if neighbor not in visited:
                new_visited = set(visited)
                new_visited.add(neighbor)
                stack.append((neighbor, path + [neighbor], new_visited))
    return None
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
"""Get all valid neighbors that differ by one letter"""
neighbors = set()
word_chars = list(word)
for i in range(len(word_chars)):
original = word_chars[i]
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
if c == original:
continue
word_chars[i] = c
new_word = ''.join(word_chars)
new_word = "".join(word_chars)
if new_word in word_set:
neighbors.add(new_word)
word_chars[i] = original
return neighbors
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
    """Generate a valid start/end word pair together with a solution path.

    Args:
        rng: Seeded random source, so sampling is reproducible per item.
        length: Word length to draw from ``self.word_sets``.

    Returns:
        Tuple ``(start, end, path)`` where *path* includes both endpoints.

    Raises:
        RuntimeError: if no valid pair is found within the attempt budget.
    """
    word_set = self.word_sets[length]
    max_attempts = 500
    for _ in range(max_attempts):
        # sorted() provides a stable ordering so rng.sample is deterministic
        # for a given seed regardless of set iteration order.
        start, end = rng.sample(sorted(word_set), 2)
        path = self._find_path(start, end, word_set)
        # Accept the path when no chain-length constraints are configured
        # (-1/-1), or when its length falls within the configured bounds.
        if path and (
            (self.config.min_chain_length == -1 and self.config.max_chain_length == -1)
            or (self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
        ):
            return start, end, path
    raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")
def __getitem__(self, idx: int) -> dict:
    """Generate a single word-ladder task item.

    Args:
        idx: Index into the virtual dataset; combined with the base seed so
            each index yields a deterministic item independent of access order.

    Returns:
        Dict with "question", "answer" (comma-joined word chain) and "metadata".
    """
    # Per-item RNG derived from the base seed keeps items reproducible.
    rng = Random(self.seed + idx)
    length = rng.randint(self.config.min_word_length, self.config.max_word_length)
    start, end, path = self._generate_word_pair(rng, length)

    return {
        "question": f"Transform the word '{start}' into '{end}' by changing one letter at a time. Each step must create a valid English word (including plurals) and keep the same word length. Show the sequence of words needed.",
        "answer": ",".join(path),
        "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
    }
# Register the task once; the duplicated call was a diff-rendering artifact.
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)