Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Cavit Erginsoy 2025-02-03 07:44:32 +00:00
commit 9b1068ea39
85 changed files with 17049 additions and 6302 deletions

View file

@ -1,13 +1,15 @@
"""Word ladder task generator"""
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Set, Dict, Tuple
from collections import deque
from typing import Dict, List, Optional, Set, Tuple
from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@dataclass
class WordLadderConfig:
"""Configuration for word ladder task generation"""
@ -32,7 +34,9 @@ class WordLadderConfig:
# Modified validation logic
if self.min_chain_length == -1:
if self.max_chain_length != -1:
assert self.max_chain_length >= 3, "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
assert (
self.max_chain_length >= 3
), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
elif self.max_chain_length == -1:
raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
else:
@ -81,16 +85,17 @@ class WordLadderDataset(ProceduralDataset):
import csv
from io import StringIO
word_sets = {}
try:
# Get CSV content as string
csv_content = read_data_file("words.csv")
# Use StringIO to create a file-like object from the string
csv_file = StringIO(csv_content)
reader = csv.DictReader(csv_file)
for row in reader:
# Process each word length column using config range
for length in range(min_length, max_length + 1):
@ -104,12 +109,12 @@ class WordLadderDataset(ProceduralDataset):
except Exception as e:
raise RuntimeError(f"Error processing words.csv content: {e}") from e
# Validate we have words for each length
for length in range(min_length, max_length + 1):
if length not in word_sets or not word_sets[length]:
raise ValueError(f"No valid words found for length {length}")
return word_sets
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
@ -160,7 +165,7 @@ class WordLadderDataset(ProceduralDataset):
# Use basic BFS for shortest path
queue = deque([(start, [start])])
visited = {start}
while queue:
current, path = queue.popleft()
if current == end:
@ -212,12 +217,7 @@ class WordLadderDataset(ProceduralDataset):
return {
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.",
"answer": ",".join(path),
"metadata": {
"start_word": start,
"end_word": end,
"word_length": length,
"chain_length": len(path)
}
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
}
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)