Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Cavit Erginsoy 2025-02-03 07:44:32 +00:00
commit 9b1068ea39
85 changed files with 17049 additions and 6302 deletions

View file

@@ -14,9 +14,9 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
__all__ = [
"SpellBackwardConfig",

View file

@@ -60,14 +60,32 @@ class BaseConversionDataset(ProceduralDataset):
value, source_base, target_base = self._generate_conversion(rng)
# Convert decimal to source base representation
source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
if source_base not in (2, 16):
source_repr = format(value, f"{source_base}x").lower().strip()
if source_base == 16:
source_repr = format(value, "x")
elif source_base == 2:
source_repr = format(value, "b")
else:
# Manual conversion for other bases
n = value
digits = []
while n:
digits.append(int(n % source_base))
n //= source_base
source_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
# Convert decimal to target base for answer
target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
if target_base not in (2, 16):
target_repr = format(value, f"{target_base}x").lower().strip()
if target_base == 16:
target_repr = format(value, "x")
elif target_base == 2:
target_repr = format(value, "b")
else:
# Manual conversion for other bases
n = value
digits = []
while n:
digits.append(int(n % target_base))
n //= target_base
target_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
source_name = self._format_base_name(source_base)
target_name = self._format_base_name(target_base)

View file

@@ -51,7 +51,7 @@ class LetterCountingDataset(ProceduralDataset):
letters = {"a"} # Fallback if span has no letters
# Select random letter that appears in the span
target_letter = rng.choice(list(letters))
target_letter = rng.choice(sorted(letters))
# Count occurrences
count = sum(word.lower().count(target_letter) for word in span)

View file

@@ -1,13 +1,15 @@
"""Word ladder task generator"""
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Set, Dict, Tuple
from collections import deque
from typing import Dict, List, Optional, Set, Tuple
from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@dataclass
class WordLadderConfig:
"""Configuration for word ladder task generation"""
@@ -32,7 +34,9 @@ class WordLadderConfig:
# Modified validation logic
if self.min_chain_length == -1:
if self.max_chain_length != -1:
assert self.max_chain_length >= 3, "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
assert (
self.max_chain_length >= 3
), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
elif self.max_chain_length == -1:
raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
else:
@@ -81,16 +85,17 @@ class WordLadderDataset(ProceduralDataset):
import csv
from io import StringIO
word_sets = {}
try:
# Get CSV content as string
csv_content = read_data_file("words.csv")
# Use StringIO to create a file-like object from the string
csv_file = StringIO(csv_content)
reader = csv.DictReader(csv_file)
for row in reader:
# Process each word length column using config range
for length in range(min_length, max_length + 1):
@@ -104,12 +109,12 @@ class WordLadderDataset(ProceduralDataset):
except Exception as e:
raise RuntimeError(f"Error processing words.csv content: {e}") from e
# Validate we have words for each length
for length in range(min_length, max_length + 1):
if length not in word_sets or not word_sets[length]:
raise ValueError(f"No valid words found for length {length}")
return word_sets
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
@@ -160,7 +165,7 @@ class WordLadderDataset(ProceduralDataset):
# Use basic BFS for shortest path
queue = deque([(start, [start])])
visited = {start}
while queue:
current, path = queue.popleft()
if current == end:
@@ -212,12 +217,7 @@ class WordLadderDataset(ProceduralDataset):
return {
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.",
"answer": ",".join(path),
"metadata": {
"start_word": start,
"end_word": end,
"word_length": length,
"chain_length": len(path)
}
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
}
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)

View file

@@ -49,7 +49,7 @@ class WordSortingDataset(ProceduralDataset):
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")
# Extract unique words within length constraints
self.words = list(
self.words = sorted(
set(
word
for word in re.findall(r"\b\w+\b", text)