mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
9b1068ea39
85 changed files with 17049 additions and 6302 deletions
|
|
@@ -14,9 +14,9 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
|||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||
|
||||
__all__ = [
|
||||
"SpellBackwardConfig",
|
||||
|
|
|
|||
|
|
@@ -60,14 +60,32 @@ class BaseConversionDataset(ProceduralDataset):
|
|||
value, source_base, target_base = self._generate_conversion(rng)
|
||||
|
||||
# Convert decimal to source base representation
|
||||
source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
|
||||
if source_base not in (2, 16):
|
||||
source_repr = format(value, f"{source_base}x").lower().strip()
|
||||
if source_base == 16:
|
||||
source_repr = format(value, "x")
|
||||
elif source_base == 2:
|
||||
source_repr = format(value, "b")
|
||||
else:
|
||||
# Manual conversion for other bases
|
||||
n = value
|
||||
digits = []
|
||||
while n:
|
||||
digits.append(int(n % source_base))
|
||||
n //= source_base
|
||||
source_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
|
||||
|
||||
# Convert decimal to target base for answer
|
||||
target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
|
||||
if target_base not in (2, 16):
|
||||
target_repr = format(value, f"{target_base}x").lower().strip()
|
||||
if target_base == 16:
|
||||
target_repr = format(value, "x")
|
||||
elif target_base == 2:
|
||||
target_repr = format(value, "b")
|
||||
else:
|
||||
# Manual conversion for other bases
|
||||
n = value
|
||||
digits = []
|
||||
while n:
|
||||
digits.append(int(n % target_base))
|
||||
n //= target_base
|
||||
target_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0])
|
||||
|
||||
source_name = self._format_base_name(source_base)
|
||||
target_name = self._format_base_name(target_base)
|
||||
|
|
|
|||
|
|
@@ -51,7 +51,7 @@ class LetterCountingDataset(ProceduralDataset):
|
|||
letters = {"a"} # Fallback if span has no letters
|
||||
|
||||
# Select random letter that appears in the span
|
||||
target_letter = rng.choice(list(letters))
|
||||
target_letter = rng.choice(sorted(letters))
|
||||
|
||||
# Count occurrences
|
||||
count = sum(word.lower().count(target_letter) for word in span)
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,15 @@
|
|||
"""Word ladder task generator"""
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Set, Dict, Tuple
|
||||
from collections import deque
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordLadderConfig:
|
||||
"""Configuration for word ladder task generation"""
|
||||
|
|
@@ -32,7 +34,9 @@ class WordLadderConfig:
|
|||
# Modified validation logic
|
||||
if self.min_chain_length == -1:
|
||||
if self.max_chain_length != -1:
|
||||
assert self.max_chain_length >= 3, "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
|
||||
assert (
|
||||
self.max_chain_length >= 3
|
||||
), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
|
||||
elif self.max_chain_length == -1:
|
||||
raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
|
||||
else:
|
||||
|
|
@@ -81,16 +85,17 @@ class WordLadderDataset(ProceduralDataset):
|
|||
|
||||
import csv
|
||||
from io import StringIO
|
||||
|
||||
word_sets = {}
|
||||
|
||||
|
||||
try:
|
||||
# Get CSV content as string
|
||||
csv_content = read_data_file("words.csv")
|
||||
|
||||
|
||||
# Use StringIO to create a file-like object from the string
|
||||
csv_file = StringIO(csv_content)
|
||||
reader = csv.DictReader(csv_file)
|
||||
|
||||
|
||||
for row in reader:
|
||||
# Process each word length column using config range
|
||||
for length in range(min_length, max_length + 1):
|
||||
|
|
@@ -104,12 +109,12 @@ class WordLadderDataset(ProceduralDataset):
|
|||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing words.csv content: {e}") from e
|
||||
|
||||
|
||||
# Validate we have words for each length
|
||||
for length in range(min_length, max_length + 1):
|
||||
if length not in word_sets or not word_sets[length]:
|
||||
raise ValueError(f"No valid words found for length {length}")
|
||||
|
||||
|
||||
return word_sets
|
||||
|
||||
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
|
||||
|
|
@@ -160,7 +165,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
# Use basic BFS for shortest path
|
||||
queue = deque([(start, [start])])
|
||||
visited = {start}
|
||||
|
||||
|
||||
while queue:
|
||||
current, path = queue.popleft()
|
||||
if current == end:
|
||||
|
|
@@ -212,12 +217,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.",
|
||||
"answer": ",".join(path),
|
||||
"metadata": {
|
||||
"start_word": start,
|
||||
"end_word": end,
|
||||
"word_length": length,
|
||||
"chain_length": len(path)
|
||||
}
|
||||
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
|
||||
}
|
||||
|
||||
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
|
||||
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
|
||||
|
|
|
|||
|
|
@@ -49,7 +49,7 @@ class WordSortingDataset(ProceduralDataset):
|
|||
# Load and preprocess text
|
||||
text = read_data_file("in_the_year_2889.txt")
|
||||
# Extract unique words within length constraints
|
||||
self.words = list(
|
||||
self.words = sorted(
|
||||
set(
|
||||
word
|
||||
for word in re.findall(r"\b\w+\b", text)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue