mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
INIT
This commit is contained in:
parent
4f14a20725
commit
d57a7947a4
6 changed files with 13417 additions and 29 deletions
|
|
@ -16,6 +16,7 @@ from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDat
|
|||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||
|
||||
__all__ = [
|
||||
"SpellBackwardConfig",
|
||||
|
|
@ -39,4 +40,6 @@ __all__ = [
|
|||
"WordSortingConfig",
|
||||
"WordSortingDataset",
|
||||
"TextTransformation",
|
||||
"WordLadderConfig",
|
||||
"WordLadderDataset",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ class WordLadderConfig:
|
|||
|
||||
min_word_length: int = 3 # Minimum word length
|
||||
max_word_length: int = 5 # Maximum word length
|
||||
min_chain_length: int = 3 # Minimum transformations required
|
||||
max_chain_length: int = 11 # Maximum transformations required
|
||||
min_chain_length: int = -1 # Set to -1 for shortest path or a minimum of 3
|
||||
max_chain_length: int = -1 # Set to -1 for shortest path or a max
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
|
|
@ -24,8 +24,16 @@ class WordLadderConfig:
|
|||
assert self.min_word_length > 2, "min_word_length must be 3"
|
||||
assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length"
|
||||
assert self.max_word_length <= 5, "max_word_length must be 5"
|
||||
assert self.min_chain_length > 2, "min_chain_length must be 3"
|
||||
assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
|
||||
|
||||
# Modified validation logic
|
||||
if self.min_chain_length == -1:
|
||||
if self.max_chain_length != -1:
|
||||
assert self.max_chain_length >= 3, "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
|
||||
elif self.max_chain_length == -1:
|
||||
raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
|
||||
else:
|
||||
assert self.min_chain_length >= 3, "min_chain_length must be 3 or -1"
|
||||
assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
|
||||
|
||||
class WordLadderDataset(ProceduralDataset):
|
||||
"""Generates word ladder transformation tasks"""
|
||||
|
|
@ -85,41 +93,83 @@ class WordLadderDataset(ProceduralDataset):
|
|||
return differences == 1
|
||||
|
||||
def _find_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
|
||||
"""Find shortest path between start and end words using BFS"""
|
||||
"""Find path between start and end words that meets length requirements"""
|
||||
if start == end:
|
||||
return [start]
|
||||
|
||||
# First find shortest path length
|
||||
shortest_path = self._bfs_shortest_path(start, end, word_set)
|
||||
if not shortest_path:
|
||||
return None
|
||||
|
||||
min_length = self.config.min_chain_length
|
||||
if len(shortest_path) > min_length:
|
||||
return shortest_path # Shortest path is already longer than required
|
||||
|
||||
# Now look for longer paths using DFS with depth constraint
|
||||
return self._dfs_with_depth(start, end, word_set, min_length)
|
||||
|
||||
def _bfs_shortest_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
|
||||
"""BFS implementation to find shortest path"""
|
||||
queue = deque([(start, [start])])
|
||||
visited = {start}
|
||||
|
||||
while queue:
|
||||
current, path = queue.popleft()
|
||||
|
||||
# Try changing one letter at a time
|
||||
word_chars = list(current)
|
||||
for i in range(len(word_chars)):
|
||||
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
||||
if word_chars[i] == c:
|
||||
continue
|
||||
|
||||
# Create new word
|
||||
word_chars[i] = c
|
||||
new_word = ''.join(word_chars)
|
||||
|
||||
# Check if it's a valid word and not visited
|
||||
if new_word in word_set and new_word not in visited:
|
||||
new_path = path + [new_word]
|
||||
if new_word == end:
|
||||
return new_path
|
||||
|
||||
queue.append((new_word, new_path))
|
||||
visited.add(new_word)
|
||||
if current == end:
|
||||
return path
|
||||
|
||||
# Restore original character
|
||||
word_chars[i] = current[i]
|
||||
|
||||
for neighbor in self._get_neighbors(current, word_set):
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor]))
|
||||
return None
|
||||
|
||||
def _dfs_with_depth(self, start: str, end: str, word_set: Set[str], target_length: int) -> Optional[List[str]]:
|
||||
"""DFS implementation looking for paths of exact length"""
|
||||
stack = [(start, [start], set([start]))]
|
||||
|
||||
while stack:
|
||||
current, path, visited = stack.pop()
|
||||
|
||||
if len(path) == target_length:
|
||||
if current == end:
|
||||
return path
|
||||
continue
|
||||
|
||||
if len(path) > target_length:
|
||||
continue
|
||||
|
||||
# Explore neighbors in random order to find different paths
|
||||
neighbors = list(self._get_neighbors(current, word_set))
|
||||
Random().shuffle(neighbors)
|
||||
|
||||
for neighbor in neighbors:
|
||||
if neighbor not in visited:
|
||||
new_visited = set(visited)
|
||||
new_visited.add(neighbor)
|
||||
stack.append((neighbor, path + [neighbor], new_visited))
|
||||
|
||||
return None
|
||||
|
||||
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
|
||||
"""Get all valid neighbors that differ by one letter"""
|
||||
neighbors = set()
|
||||
word_chars = list(word)
|
||||
|
||||
for i in range(len(word_chars)):
|
||||
original = word_chars[i]
|
||||
for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
||||
if c == original:
|
||||
continue
|
||||
word_chars[i] = c
|
||||
new_word = ''.join(word_chars)
|
||||
if new_word in word_set:
|
||||
neighbors.add(new_word)
|
||||
word_chars[i] = original
|
||||
|
||||
return neighbors
|
||||
|
||||
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
|
||||
"""Generate valid start/end words with solution path"""
|
||||
word_set = self.word_sets[length]
|
||||
|
|
@ -128,7 +178,10 @@ class WordLadderDataset(ProceduralDataset):
|
|||
for _ in range(max_attempts):
|
||||
start, end = rng.sample(sorted(word_set), 2)
|
||||
path = self._find_path(start, end, word_set)
|
||||
if path and self.config.min_chain_length <= len(path) <= self.config.max_chain_length:
|
||||
if path and (
|
||||
(self.config.min_chain_length == -1 and self.config.max_chain_length == -1) or
|
||||
(self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
|
||||
):
|
||||
return start, end, path
|
||||
|
||||
raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue