mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
Revert "Restructure {reasoning_gym, tests}/{core, exercises, curricula}"
This reverts commit 10dbb374b0.
This commit is contained in:
parent
b756f26c09
commit
4c3ae0aebf
109 changed files with 0 additions and 0 deletions
45
reasoning_gym/algorithmic/__init__.py
Normal file
45
reasoning_gym/algorithmic/__init__.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""
|
||||
Algorithmic tasks for training reasoning capabilities:
|
||||
- Text processing
|
||||
- Counting
|
||||
- Sorting
|
||||
- Pattern matching
|
||||
"""
|
||||
|
||||
from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||
from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
|
||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||
|
||||
# Names re-exported by this package; each group corresponds to one of the
# submodule imports above.
__all__ = [
    # spell_backward
    "SpellBackwardConfig",
    "SpellBackwardDataset",
    # base_conversion
    "BaseConversionConfig",
    "BaseConversionDataset",
    # caesar_cipher
    "CaesarCipherConfig",
    "CaesarCipherDataset",
    # letter_counting
    "LetterCountingConfig",
    "LetterCountingDataset",
    # letter_jumble
    "LetterJumbleConfig",
    "LetterJumbleDataset",
    # number_filtering
    "NumberFilteringConfig",
    "NumberFilteringDataset",
    # number_sorting
    "NumberSortingConfig",
    "NumberSortingDataset",
    # sentence_reordering
    "SentenceReorderingConfig",
    "SentenceReorderingDataset",
    # word_sequence_reversal
    "WordSequenceReversalConfig",
    "WordSequenceReversalDataset",
    # word_sorting
    "WordSortingConfig",
    "WordSortingDataset",
    "TextTransformation",
    # word_ladder
    "WordLadderConfig",
    "WordLadderDataset",
]
|
||||
109
reasoning_gym/algorithmic/base_conversion.py
Normal file
109
reasoning_gym/algorithmic/base_conversion.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Base conversion task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class BaseConversionConfig:
    """Configuration for base conversion task generation"""

    min_base: int = 2  # Minimum base (2=binary)
    max_base: int = 16  # Maximum base (16=hex)
    min_value: int = 0  # Minimum decimal value to convert
    max_value: int = 1000  # Maximum decimal value to convert
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if a base lies outside 2..36, if fewer than two
                distinct bases are available, or if the value range is
                negative or empty.
        """
        assert 2 <= self.min_base <= 36, "min_base must be between 2 and 36"
        assert self.min_base <= self.max_base <= 36, "max_base must be between min_base and 36"
        # The task generator keeps redrawing the target base until it differs
        # from the source base; with a single allowed base that never ends.
        assert self.max_base > self.min_base, "max_base must be > min_base so source and target bases can differ"
        assert self.min_value >= 0, "min_value must be non-negative"
        assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class BaseConversionDataset(ProceduralDataset):
    """Generates tasks that ask to convert a number from one base to another"""

    # Digit alphabet for bases up to 36: 0-9 followed by lowercase a-z
    _DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz"

    def __init__(self, config: BaseConversionConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _format_base_name(self, base: int) -> str:
        """Get human-readable name for common bases"""
        if base == 2:
            return "binary"
        elif base == 16:
            return "hexadecimal"
        else:
            return f"base-{base}"

    def _to_base(self, value: int, base: int) -> str:
        """Return ``value`` written in ``base`` using digits 0-9 and lowercase a-z.

        Zero is rendered as "0"; the previous inline conversion produced an
        empty string for it because ``reversed(...)`` is always truthy.
        """
        if value == 0:
            return "0"
        sign = "-" if value < 0 else ""
        remaining = abs(value)
        digits = []
        while remaining:
            remaining, remainder = divmod(remaining, base)
            digits.append(self._DIGITS[remainder])
        return sign + "".join(reversed(digits))

    def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
        """Generate random value and distinct source/target bases.

        Raises:
            ValueError: if the config allows only one base, in which case a
                target base different from the source base cannot exist.
        """
        if self.config.min_base == self.config.max_base:
            raise ValueError("min_base and max_base must differ to pick distinct source and target bases")

        value = rng.randint(self.config.min_value, self.config.max_value)

        # Choose source and target bases
        source_base = rng.randint(self.config.min_base, self.config.max_base)
        target_base = rng.randint(self.config.min_base, self.config.max_base)
        while target_base == source_base:  # Ensure different bases
            target_base = rng.randint(self.config.min_base, self.config.max_base)

        return value, source_base, target_base

    def __getitem__(self, idx: int) -> dict:
        """Generate a single base conversion task.

        Returns a dict with the question text, the expected answer (the value
        in the target base) and metadata describing the conversion.
        """
        # Per-item RNG derived from the dataset seed and the index
        rng = Random(self.seed + idx)

        value, source_base, target_base = self._generate_conversion(rng)

        source_repr = self._to_base(value, source_base)
        target_repr = self._to_base(value, target_base)

        source_name = self._format_base_name(source_base)
        target_name = self._format_base_name(target_base)

        # Add hint for bases > 10 about using lowercase letters
        hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""

        return {
            "question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
            "answer": target_repr,
            "metadata": {
                "decimal_value": value,
                "source_base": source_base,
                "target_base": target_base,
                "source_repr": source_repr,
                "target_repr": target_repr,
            },
        }
|
||||
|
||||
|
||||
register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig)
|
||||
84
reasoning_gym/algorithmic/caesar_cipher.py
Normal file
84
reasoning_gym/algorithmic/caesar_cipher.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Caesar cipher task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from string import ascii_uppercase
|
||||
from typing import List, Optional
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class CaesarCipherConfig:
    """Settings that control how Caesar cipher tasks are produced"""

    delimiter: str = "."  # Text is cut into sentences at this string
    min_words: int = 3  # Fewest words a sentence may have
    max_words: int = 20  # Most words a sentence may have
    min_rotation: int = 1  # Smallest rotation to apply
    max_rotation: int = 25  # Largest rotation to apply
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that word counts and rotation bounds are consistent"""
        lo_words, hi_words = self.min_words, self.max_words
        assert 0 < lo_words, "min_words must be positive"
        assert lo_words <= hi_words, "max_words must be >= min_words"
        rotations_ok = 0 < self.min_rotation <= self.max_rotation < 26
        assert rotations_ok, "rotation must be in range [1,25]"
|
||||
|
||||
|
||||
class CaesarCipherDataset(ProceduralDataset):
    """Generates Caesar cipher encryption/decryption tasks"""

    def __init__(self, config: CaesarCipherConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Load and preprocess text
        text = read_data_file("in_the_year_2889.txt")

        # Split into sentences and drop empty fragments
        sentences = [s.strip() for s in text.split(config.delimiter) if s.strip()]

        # Keep sentences whose count of purely alphabetic words is within the
        # configured bounds; the stored sentence is the upper-cased words only.
        self.valid_sentences = []
        for sentence in sentences:
            words = [w.upper() for w in sentence.split() if w.isalpha()]
            if self.config.min_words <= len(words) <= self.config.max_words:
                self.valid_sentences.append(" ".join(words))

    def _caesar_encrypt(self, text: str, rotation: int) -> str:
        """Apply Caesar cipher encryption with given rotation.

        Only ASCII letters are rotated (case is preserved). Every other
        character, including non-ASCII letters, is copied unchanged; rotating
        those relative to "A" would turn them into unrelated letters that
        cannot be decrypted.
        """
        shift = rotation % 26
        rotated_upper = ascii_uppercase[shift:] + ascii_uppercase[:shift]
        table = str.maketrans(
            ascii_uppercase + ascii_uppercase.lower(),
            rotated_upper + rotated_upper.lower(),
        )
        return text.translate(table)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Caesar cipher task.

        Returns a dict with the question (cipher text to decrypt), the clear
        text answer, and metadata holding the rotation and both texts.
        """
        # Per-item RNG derived from the dataset seed and the index
        rng = Random(self.seed + idx)

        # Select random sentence and rotation
        sentence = rng.choice(self.valid_sentences)
        rotation = rng.randint(self.config.min_rotation, self.config.max_rotation)

        # Generate cipher text
        cipher_text = self._caesar_encrypt(sentence, rotation)

        return {
            "question": f"Decrypt this Caesar cipher text: {cipher_text}",
            "answer": sentence,
            "metadata": {"rotation": rotation, "cipher_text": cipher_text, "clear_text": sentence},
        }
|
||||
|
||||
|
||||
register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig)
|
||||
66
reasoning_gym/algorithmic/letter_counting.py
Normal file
66
reasoning_gym/algorithmic/letter_counting.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Letter counting task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class LetterCountingConfig:
    """Settings that control how letter counting tasks are produced"""

    min_words: int = 5  # Shortest span, in words
    max_words: int = 15  # Longest span, in words
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the span length bounds are positive and ordered"""
        shortest, longest = self.min_words, self.max_words
        assert 0 < shortest, "min_words must be positive"
        assert shortest <= longest, "max_words must be >= min_words"
|
||||
|
||||
|
||||
class LetterCountingDataset(ProceduralDataset):
    """Generates letter counting tasks from text spans"""

    def __init__(self, config: LetterCountingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Load and preprocess text
        text = read_data_file("in_the_year_2889.txt")
        # Extract words and keep those made only of alphanumeric characters
        self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single letter counting task.

        Returns a dict with the question, the count as a string, and metadata
        holding the span length, the target letter and the span words.
        """
        # Per-item RNG derived from the dataset seed and the index
        rng = Random(self.seed + idx)

        # Select random span of words
        span_length = rng.randint(self.config.min_words, self.config.max_words)
        start_idx = rng.randint(0, len(self.words) - span_length)
        span = self.words[start_idx : start_idx + span_length]

        # Collect the distinct letters of the span. Words may contain digits
        # (they are filtered with isalnum), so digits are excluded here to
        # avoid asking how often the "letter" 2 appears.
        letters = {ch for ch in "".join(span).lower() if ch.isalpha()}
        if not letters:
            letters = {"a"}  # Fallback if span has no letters

        # Select random letter that appears in the span (sorted for determinism)
        target_letter = rng.choice(sorted(letters))

        # Count occurrences, case-insensitively
        count = sum(word.lower().count(target_letter) for word in span)

        return {
            "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
            "answer": str(count),
            "metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
        }
|
||||
|
||||
|
||||
register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig)
|
||||
103
reasoning_gym/algorithmic/letter_jumble.py
Normal file
103
reasoning_gym/algorithmic/letter_jumble.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Word letter jumbling task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class LetterJumbleConfig:
    """Settings that control how letter jumbling tasks are produced"""

    min_word_len: int = 1  # Shortest word accepted
    max_word_len: int = 64  # Longest word accepted
    min_words: int = 3  # Fewest words in one task
    max_words: int = 20  # Most words in one task
    min_corruption_level: float = 0.1  # Lower bound of the fraction of characters to swap
    max_corruption_level: float = 0.9  # Upper bound of the fraction of characters to swap
    consecutive_words: bool = True  # Take adjacent words from the text instead of a random sample
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that all length, count and corruption bounds are consistent"""
        assert 0 < self.min_word_len, "min_word_len must be positive"
        assert self.min_word_len <= self.max_word_len, "max_word_len must be >= min_word_len"
        assert 0 < self.min_words, "min_words must be positive"
        assert self.min_words <= self.max_words, "max_words must be >= min_words"
        lo_level, hi_level = self.min_corruption_level, self.max_corruption_level
        assert 0 <= lo_level <= 1, "min_corruption_level must be in [0,1]"
        assert 0 <= hi_level <= 1, "max_corruption_level must be in [0,1]"
        assert lo_level <= hi_level, "max_corruption_level must be >= min_corruption_level"
|
||||
|
||||
|
||||
class LetterJumbleDataset(ProceduralDataset):
    """Produces tasks in which the letters of each word have been shuffled"""

    def __init__(self, config: LetterJumbleConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Read the source text and keep alphabetic words of an accepted length
        text = read_data_file("in_the_year_2889.txt")
        shortest = self.config.min_word_len
        longest = self.config.max_word_len
        self.words = []
        for candidate in re.findall(r"\b\w+\b", text):
            if shortest <= len(candidate) <= longest and candidate.isalpha():
                self.words.append(candidate)

    def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str:
        """Return ``word`` after swapping randomly chosen pairs of characters.

        The number of swaps is the word length times ``corruption_level``,
        truncated, but never fewer than one. One-character words are
        returned unchanged.
        """
        length = len(word)
        if length < 2:
            return word

        chars = list(word)
        swaps_left = max(1, int(length * corruption_level))
        while swaps_left > 0:
            # Two distinct positions, exchanged in place
            i, j = rng.sample(range(length), 2)
            chars[i], chars[j] = chars[j], chars[i]
            swaps_left -= 1

        return "".join(chars)

    def __getitem__(self, idx: int) -> dict:
        """Build the jumbling task for position ``idx``"""
        rng = Random(self.seed + idx)

        # Draw the task size first, then the corruption level
        num_words = rng.randint(self.config.min_words, self.config.max_words)
        corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level)

        if not self.config.consecutive_words:
            # Independent random sample of words
            selected_words = rng.sample(self.words, num_words)
        else:
            # A run of adjacent words starting at a random offset
            start_idx = rng.randint(0, len(self.words) - num_words)
            selected_words = self.words[start_idx : start_idx + num_words]

        # Scramble the words one after another, in order
        scrambled_words = []
        for original in selected_words:
            scrambled_words.append(self._scramble_word(original, corruption_level, rng))

        metadata = {
            "num_words": num_words,
            "corruption_level": corruption_level,
            "scrambled_words": scrambled_words,
            "original_words": selected_words,
        }
        return {
            "question": f"Unscramble these words: {' '.join(scrambled_words)}",
            "answer": " ".join(selected_words),
            "metadata": metadata,
        }
|
||||
|
||||
|
||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
||||
101
reasoning_gym/algorithmic/number_filtering.py
Normal file
101
reasoning_gym/algorithmic/number_filtering.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""Number filtering task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class NumberFilteringConfig:
    """Settings that control how number filtering tasks are produced"""

    min_numbers: int = 3  # Shortest list length
    max_numbers: int = 10  # Longest list length
    min_decimals: int = 0  # Fewest decimal places
    max_decimals: int = 4  # Most decimal places
    min_value: float = -100.0  # Lower bound for generated numbers
    max_value: float = 100.0  # Upper bound for generated numbers
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the count, precision and value bounds are consistent"""
        assert 0 < self.min_numbers, "min_numbers must be positive"
        assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
        assert 0 <= self.min_decimals, "min_decimals must be non-negative"
        assert self.min_decimals <= self.max_decimals, "max_decimals must be >= min_decimals"
        assert self.min_value < self.max_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class NumberFilteringDataset(ProceduralDataset):
    """Generates number filtering tasks"""

    def __init__(self, config: NumberFilteringConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _format_number(self, num: float, decimals: int) -> str:
        """Format a number with specified decimal places"""
        return f"{num:.{decimals}f}"

    def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
        """Generate list of numbers and their string representations.

        Each number gets its own randomly chosen number of decimal places; the
        float stored for it is re-parsed from the formatted string so both
        lists describe exactly the same values.
        """
        count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        numbers = []
        str_numbers = []

        for _ in range(count):
            num = rng.uniform(self.config.min_value, self.config.max_value)
            decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
            str_num = self._format_number(num, decimals)
            numbers.append(float(str_num))  # Convert back to simulate precision loss
            str_numbers.append(str_num)

        return numbers, str_numbers

    def __getitem__(self, idx: int) -> dict:
        """Generate a single number filtering task.

        Returns a dict with the question, the stringified list of surviving
        number strings as the answer, and metadata describing the task.
        """
        # Per-item RNG derived from the dataset seed and the index
        rng = Random(self.seed + idx)

        # Generate numbers and their string representations
        numbers, str_numbers = self._generate_numbers(rng)

        # Determine filter value between min and max of generated numbers
        min_val = min(numbers)
        max_val = max(numbers)
        filter_value = rng.uniform(min_val, max_val)
        decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
        filter_str = self._format_number(filter_value, decimals)
        filter_value = float(filter_str)  # Convert back to simulate precision loss

        # Randomly choose filter operation
        keep_larger = rng.choice([True, False])
        larger_smaller = "larger" if keep_larger else "smaller"
        keep_remove = "keep" if rng.choice([True, False]) else "remove"

        def survives(n: float) -> bool:
            """Tell whether ``n`` stays in the list under the chosen operation."""
            if keep_remove == "keep":
                return n > filter_value if keep_larger else n < filter_value
            # remove: everything not strictly on the removed side stays
            return n <= filter_value if keep_larger else n >= filter_value

        # Pair each value with its own string so equal values that were
        # formatted differently (e.g. "5" and "5.0") keep their original text.
        result_strs = [s for n, s in zip(numbers, str_numbers) if survives(n)]

        return {
            "question": (
                f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
                f"in this list: {str_numbers}"
            ),
            "answer": str(result_strs),
            "metadata": {
                "original_numbers": str_numbers,
                "filter_value": filter_str,
                "operation": f"{keep_remove}_{larger_smaller}",
                "result": result_strs,
            },
        }
|
||||
|
||||
|
||||
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig)
|
||||
89
reasoning_gym/algorithmic/number_sorting.py
Normal file
89
reasoning_gym/algorithmic/number_sorting.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Number sorting task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class NumberSortingConfig:
    """Settings that control how number sorting tasks are produced"""

    min_numbers: int = 3  # Fewest numbers in one task
    max_numbers: int = 10  # Most numbers in one task
    min_decimals: int = 0  # Fewest decimal places
    max_decimals: int = 2  # Most decimal places
    min_value: float = -100.0  # Lower bound for generated numbers
    max_value: float = 100.0  # Upper bound for generated numbers
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the count, precision and value bounds are consistent"""
        assert 0 < self.min_numbers, "min_numbers must be positive"
        assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
        assert 0 <= self.min_decimals, "min_decimals must be non-negative"
        assert self.max_decimals >= self.min_decimals, "max_decimals must be >= min_decimals"
        assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class NumberSortingDataset(ProceduralDataset):
    """Produces tasks that ask to sort a list of numbers"""

    def __init__(self, config: NumberSortingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _format_number(self, num: float, decimals: int) -> str:
        """Render ``num`` with ``decimals`` decimal places"""
        spec = f".{decimals}f"
        # Round-trip through float once before the final formatting
        rounded = float(format(num, spec))
        return format(rounded, spec)

    def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
        """Draw the numbers of one task, returned as floats and as strings.

        All numbers of a task share one randomly chosen number of decimal
        places; each float is parsed back from its string form.
        """
        count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)

        number_strs = [
            self._format_number(rng.uniform(self.config.min_value, self.config.max_value), decimals)
            for _ in range(count)
        ]
        numbers = [float(text) for text in number_strs]
        return numbers, number_strs

    def __getitem__(self, idx: int) -> dict:
        """Build the sorting task for position ``idx``"""
        rng = Random(self.seed + idx)

        numbers, number_strs = self._generate_numbers(rng)

        # Recover the shared number of decimal places from the first string
        _, _, fraction = number_strs[0].partition(".")
        decimals = len(fraction)

        # Pick the direction, then sort and format only that ordering
        is_ascending = rng.choice([True, False])
        if is_ascending:
            direction = "ascending"
            ordered = sorted(numbers)
        else:
            direction = "descending"
            ordered = sorted(numbers, reverse=True)
        answer = [self._format_number(value, decimals) for value in ordered]

        return {
            "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
            "answer": str(answer),
            "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
        }
|
||||
|
||||
|
||||
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig)
|
||||
96
reasoning_gym/algorithmic/sentence_reordering.py
Normal file
96
reasoning_gym/algorithmic/sentence_reordering.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Sentence re-ordering task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class SentenceReorderingConfig:
    """Configuration for sentence reordering task generation"""

    min_words_in_sentence: int = 3  # Minimum words a sentence must have
    max_words_in_sentence: int = 20  # Maximum words a sentence may have
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if the minimum is not positive or exceeds the maximum.
        """
        assert self.min_words_in_sentence > 0, "min_words_in_sentence must be positive"
        assert (
            self.max_words_in_sentence >= self.min_words_in_sentence
        ), "max_words_in_sentence must be >= min_words_in_sentence"
|
||||
|
||||
|
||||
class SentenceReorderingDataset(ProceduralDataset):
    """Generates sentence reordering tasks from text spans"""

    def __init__(self, config: SentenceReorderingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Load and preprocess text
        text = read_data_file("in_the_year_2889.txt")
        # Keep sentences (including their ending punctuation) whose number of
        # \w+ words lies within the configured bounds.
        self.sentences = [
            sentence.strip()
            for sentence in re.findall(r"[^.!?]+[.!?]", text)
            if self.config.min_words_in_sentence
            <= len(re.findall(r"\b\w+\b", sentence))
            <= self.config.max_words_in_sentence
        ]

    def _generate_sentence_dataset(self, sentence: str, seed: int, idx: int, shuffle=True):
        """
        Generate a procedural dataset by shuffling the words in the input sentence.

        Args:
            sentence (str): The correct sentence to use for dataset generation.
            seed (int): The seed to use for random number generation.
            idx (int): The index to add to the seed for random number generation.
            shuffle (bool): Whether to shuffle the words to create the input sentence.

        Returns:
            dict: A dictionary containing the input sentence and the correct sentence (goal).
        """
        rng = Random(seed + idx)
        words = sentence.split()  # Whitespace-separated tokens; punctuation stays attached
        scrambled_words = words.copy()
        if shuffle:
            rng.shuffle(scrambled_words)  # Shuffle the words to generate the input
        input_sentence = " ".join(scrambled_words)
        goal_sentence = " ".join(words)
        return {"input": input_sentence, "goal": goal_sentence}

    def __getitem__(self, idx: int) -> dict:
        """Generate a single sentence reordering task.

        Returns a dict with the question (shuffled words), the answer (the
        words in their original order) and the word count as metadata.
        """
        rng = Random(self.seed + idx)
        sentence_dataset = self._generate_sentence_dataset(rng.choice(self.sentences), self.seed, idx)

        question = sentence_dataset["input"]
        # The goal sentence is the answer. Re-deriving it by sorting the
        # shuffled words on their first index in the goal moved repeated words
        # next to each other and so produced a wrong sentence.
        solved_sentence = sentence_dataset["goal"]
        # Count \w+ words, the same measure used when filtering sentences
        word_count = len(re.findall(r"\b\w+\b", solved_sentence))

        return {
            "question": f"Restore the correct order of words in the following sentence: {question}",
            "answer": solved_sentence,
            "metadata": {"word_count": word_count},
        }
|
||||
|
||||
|
||||
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)
|
||||
53
reasoning_gym/algorithmic/spell_backward.py
Normal file
53
reasoning_gym/algorithmic/spell_backward.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Spell backward task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class SpellBackwardConfig:
    """Settings that control how spell-backward tasks are produced"""

    min_word_len: int = 3  # Shortest word accepted
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the minimum word length is positive"""
        assert 0 < self.min_word_len, "min_word_len must be positive"
|
||||
|
||||
|
||||
class SpellBackwardDataset(ProceduralDataset):
    """Produces tasks that ask for a word spelled in reverse"""

    def __init__(self, config: SpellBackwardConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Read the source text and keep alphanumeric words that are long enough
        text = read_data_file("in_the_year_2889.txt")
        self.words = []
        for candidate in re.findall(r"\b\w+\b", text):
            if candidate.isalnum() and len(candidate) >= config.min_word_len:
                self.words.append(candidate)

    def __getitem__(self, idx: int) -> dict:
        """Build the spell-backward task for position ``idx``"""
        rng = Random(self.seed + idx)

        # Pick one word and reverse its characters
        word = rng.choice(self.words)
        reversed_word = "".join(reversed(word))

        metadata = {"word": word, "word_len": len(word)}
        return {
            "question": f"Spell this word backward (example: sun -> nus): {word}",
            "answer": reversed_word,
            "metadata": metadata,
        }
|
||||
|
||||
|
||||
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
|
||||
208
reasoning_gym/algorithmic/word_ladder.py
Normal file
208
reasoning_gym/algorithmic/word_ladder.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
"""Word ladder task generator"""
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class WordLadderConfig:
    """Configuration for word ladder task generation.

    Word lengths are restricted to the range 3..5. Chain lengths count the
    words in the ladder and use ``-1`` as a sentinel; the accepted
    combinations are enforced by :meth:`validate`.
    """

    min_word_length: int = 3  # Minimum word length (must be >= 3)
    max_word_length: int = 5  # Maximum word length (must be <= 5)
    min_chain_length: int = -1  # Set to -1 for shortest path or a minimum of 3
    max_chain_length: int = -1  # Set to -1 for shortest path or a max
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Accepted chain-length combinations:

        * ``min_chain_length == -1`` and ``max_chain_length == -1``
        * ``min_chain_length == -1`` and ``max_chain_length >= 3``
        * ``min_chain_length >= 3`` and ``max_chain_length >= min_chain_length``

        Raises:
            AssertionError: if any word-length or chain-length constraint is
                violated.
        """
        # The messages state the actual bound being checked (a range, not a
        # single permitted value).
        assert self.min_word_length > 2, "min_word_length must be >= 3"
        assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length"
        assert self.max_word_length <= 5, "max_word_length must be <= 5"

        if self.min_chain_length == -1:
            # Shortest-path mode: an explicit cap is optional but must be sane.
            if self.max_chain_length != -1:
                assert (
                    self.max_chain_length >= 3
                ), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
        elif self.max_chain_length == -1:
            # An explicit minimum needs an explicit maximum.
            raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
        else:
            assert self.min_chain_length >= 3, "min_chain_length must be >= 3 or -1"
            assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
|
||||
|
||||
|
||||
class WordLadderDataset(ProceduralDataset):
    """Generates word ladder transformation tasks.

    A word ladder turns a start word into an end word by changing one letter
    per step, where every intermediate word must be in the loaded word set.
    Words come from ``words.csv`` and are stored upper-cased, grouped by
    length.

    Path search is deterministic for a given RNG: neighbours are visited in
    sorted order (set iteration order of strings varies between interpreter
    runs) and any shuffling uses the caller-supplied ``Random`` instance.
    """

    def __init__(self, config: WordLadderConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Load words from CSV file
        self.word_sets = self._load_words_from_csv()

    def _load_words_from_csv(self) -> Dict[int, Set[str]]:
        """Load words from ``words.csv`` organized by length.

        The columns ``3_letter``, ``4_letter`` and ``5_letter`` are read; only
        lengths inside ``[config.min_word_length, config.max_word_length]``
        are kept.

        Returns:
            Mapping from word length to the set of upper-cased words.

        Raises:
            RuntimeError: if reading or parsing the CSV content fails.
            ValueError: if a configured length ends up with no words.
        """
        import csv
        from io import StringIO

        word_sets: Dict[int, Set[str]] = {}

        try:
            # read_data_file returns the CSV content as a string; wrap it in a
            # file-like object for the csv module.
            csv_content = read_data_file("words.csv")
            reader = csv.DictReader(StringIO(csv_content))

            for row in reader:
                # Process each word length column
                for length in range(3, 6):
                    word = row.get(f"{length}_letter", "")

                    if not word:  # Skip empty entries
                        continue

                    if self.config.min_word_length <= length <= self.config.max_word_length:
                        word_sets.setdefault(length, set()).add(word.upper())

        except Exception as e:
            raise RuntimeError(f"Error processing words.csv content: {e}") from e

        # Validate we have words for each length
        for length in range(self.config.min_word_length, self.config.max_word_length + 1):
            if length not in word_sets or not word_sets[length]:
                raise ValueError(f"No valid words found for length {length}")

        return word_sets

    def _differs_by_one(self, word1: str, word2: str) -> bool:
        """Check if two words have equal length and differ in exactly one position."""
        if len(word1) != len(word2):
            return False
        return sum(c1 != c2 for c1, c2 in zip(word1, word2)) == 1

    def _find_path(
        self, start: str, end: str, word_set: Set[str], rng: Optional[Random] = None
    ) -> Optional[List[str]]:
        """Find a path between start and end that meets the minimum chain length.

        Args:
            start: First word of the ladder.
            end: Last word of the ladder.
            word_set: Words allowed as ladder steps.
            rng: Random source for the depth-constrained search. When omitted
                a fresh unseeded ``Random`` is used, so results are then not
                reproducible.

        Returns:
            The list of words from ``start`` to ``end`` (inclusive), or
            ``None`` when no suitable path exists. ``[start]`` is returned
            when both words are identical.
        """
        if start == end:
            return [start]

        # First find the shortest path
        shortest_path = self._bfs_shortest_path(start, end, word_set)
        if not shortest_path:
            return None

        min_length = self.config.min_chain_length
        # ">=": a shortest path that exactly meets the minimum is already
        # valid, so the exponential DFS below is skipped. With the -1
        # sentinel this is always true.
        if len(shortest_path) >= min_length:
            return shortest_path

        # Shortest path is too short: look for a path of exactly min_length
        return self._dfs_with_depth(start, end, word_set, min_length, rng)

    def _bfs_shortest_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
        """Breadth-first search for a shortest path from start to end.

        Returns:
            The path as a list of words including both endpoints, or ``None``
            if ``end`` is unreachable.
        """
        queue = deque([(start, [start])])
        visited = {start}

        while queue:
            current, path = queue.popleft()
            if current == end:
                return path

            # Sorted so that ties between equally short paths are broken the
            # same way in every interpreter run.
            for neighbor in sorted(self._get_neighbors(current, word_set)):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))
        return None

    def _dfs_with_depth(
        self,
        start: str,
        end: str,
        word_set: Set[str],
        target_length: int,
        rng: Optional[Random] = None,
    ) -> Optional[List[str]]:
        """Depth-first search for a simple path with exactly ``target_length`` words.

        Args:
            start: First word of the path.
            end: Required last word of the path.
            word_set: Words allowed as steps.
            target_length: Number of words in the path, endpoints included.
            rng: Random source used to shuffle the exploration order. When
                omitted a fresh unseeded ``Random`` is used.

        Returns:
            A path of exactly ``target_length`` words ending at ``end``, or
            ``None`` if none exists.
        """
        if rng is None:
            rng = Random()

        stack = [(start, [start], {start})]

        while stack:
            current, path, visited = stack.pop()

            if len(path) >= target_length:
                if len(path) == target_length and current == end:
                    return path
                continue

            # A word may appear only once per path, so a path that has already
            # reached `end` before the target length can never finish at `end`.
            if current == end:
                continue

            # Explore neighbors in random order to find different paths; sort
            # first so the shuffle result depends only on the RNG state.
            neighbors = sorted(self._get_neighbors(current, word_set))
            rng.shuffle(neighbors)

            for neighbor in neighbors:
                if neighbor not in visited:
                    stack.append((neighbor, path + [neighbor], visited | {neighbor}))

        return None

    def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
        """Get all words in ``word_set`` that differ from ``word`` by one letter.

        Candidates are built by substituting each position with every
        upper-case ASCII letter, so ``word`` is expected to be upper-case like
        the loaded word sets.
        """
        neighbors = set()

        for i, original in enumerate(word):
            prefix, suffix = word[:i], word[i + 1 :]
            for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
                if c == original:
                    continue
                candidate = prefix + c + suffix
                if candidate in word_set:
                    neighbors.add(candidate)

        return neighbors

    def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
        """Generate valid start/end words with a solution path.

        Args:
            rng: Random source for sampling word pairs and for path search.
            length: Word length to use.

        Returns:
            ``(start, end, path)`` where ``path`` satisfies the configured
            chain-length bounds.

        Raises:
            RuntimeError: if no valid pair is found within the attempt limit.
        """
        word_set = self.word_sets[length]
        # Sample from a sorted list (sets have no stable order); build it once
        # instead of once per attempt.
        candidates = sorted(word_set)
        max_attempts = 500

        for _ in range(max_attempts):
            start, end = rng.sample(candidates, 2)
            path = self._find_path(start, end, word_set, rng)
            if path and (
                (self.config.min_chain_length == -1 and self.config.max_chain_length == -1)
                or (self.config.min_chain_length <= len(path) <= self.config.max_chain_length)
            ):
                return start, end, path

        raise RuntimeError(f"Failed to find valid pair for length {length} after {max_attempts} attempts")

    def __getitem__(self, idx: int) -> dict:
        """Generate a single word ladder task.

        All randomness for the item comes from ``Random(self.seed + idx)``.

        Returns:
            A dict with ``question``, ``answer`` (comma-separated ladder) and
            ``metadata`` (start/end word, word length, chain length).
        """
        rng = Random(self.seed + idx)
        length = rng.randint(self.config.min_word_length, self.config.max_word_length)
        start, end, path = self._generate_word_pair(rng, length)

        return {
            "question": f"Transform the word '{start}' into '{end}' by changing one letter at a time. Each step must create a valid English word (including plurals) and keep the same word length. Show the sequence of words needed.",
            "answer": ",".join(path),
            "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
        }
|
||||
|
||||
|
||||
# Register the dataset class and its config class with the factory under the name "word_ladder".
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
|
||||
58
reasoning_gym/algorithmic/word_sequence_reversal.py
Normal file
58
reasoning_gym/algorithmic/word_sequence_reversal.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Word reversal task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class WordSequenceReversalConfig:
    """Settings for the word-sequence reversal task generator.

    Call :meth:`validate` to check the values; it raises ``AssertionError``
    when a constraint is violated.
    """

    # Smallest number of words placed in a list.
    min_words: int = 3
    # Largest number of words placed in a list.
    max_words: int = 8
    # Base seed for item generation.
    seed: Optional[int] = None
    # Number of items the (virtual) dataset reports.
    size: int = 500

    def validate(self) -> None:
        """Check that the word-count bounds are usable.

        Raises:
            AssertionError: if ``min_words`` is not positive or ``max_words``
                is smaller than ``min_words``.
        """
        has_positive_minimum = self.min_words > 0
        assert has_positive_minimum, "min_words must be positive"

        bounds_are_ordered = self.max_words >= self.min_words
        assert bounds_are_ordered, "max_words must be >= min_words"
|
||||
|
||||
|
||||
class WordSequenceReversalDataset(ProceduralDataset):
    """Procedural dataset whose items ask for a list of words in reverse order.

    The word pool is every purely alphanumeric ``\\b\\w+\\b`` token of the
    bundled text ``in_the_year_2889.txt``, in document order and with
    repetitions kept.
    """

    def __init__(self, config: WordSequenceReversalConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        corpus = read_data_file("in_the_year_2889.txt")
        # ``\w`` also matches "_"; ``isalnum`` drops tokens containing it.
        tokens = re.findall(r"\b\w+\b", corpus)
        self.words = list(filter(str.isalnum, tokens))

    def __getitem__(self, idx: int) -> dict:
        """Build the item at position ``idx``.

        A dedicated ``Random(self.seed + idx)`` first draws the list length
        from ``[config.min_words, config.max_words]`` and then samples that
        many distinct positions of the word pool.

        Returns:
            A dict with ``question``, ``answer`` (the words in reverse order,
            comma-separated) and ``metadata`` (word count and word list).
        """
        rng = Random(self.seed + idx)

        num_words = rng.randint(self.config.min_words, self.config.max_words)
        positions = rng.sample(range(len(self.words)), num_words)
        words = [self.words[pos] for pos in positions]

        question = ", ".join(words)
        answer = ", ".join(words[::-1])

        return {
            "question": f"Reverse this list of words: {question}",
            "answer": answer,
            "metadata": {"num_words": num_words, "words": words},
        }
|
||||
|
||||
|
||||
# Register the dataset class and its config class with the factory under the name "word_sequence_reversal".
register_dataset("word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig)
|
||||
109
reasoning_gym/algorithmic/word_sorting.py
Normal file
109
reasoning_gym/algorithmic/word_sorting.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Word sorting task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
class TextTransformation(StrEnum):
|
||||
"""Text transformation options"""
|
||||
|
||||
LOWERCASE = "lowercase"
|
||||
UPPERCASE = "uppercase"
|
||||
ORIGINAL = "original"
|
||||
RANDOMCASE = "randomcase"
|
||||
|
||||
|
||||
@dataclass
class WordSortingConfig:
    """Settings for the word sorting task generator.

    Call :meth:`validate` to check the values; it raises ``AssertionError``
    when a constraint is violated.
    """

    # Smallest number of words in one task.
    min_words: int = 3
    # Largest number of words in one task.
    max_words: int = 10
    # Shortest word admitted to the vocabulary.
    min_word_length: int = 3
    # Longest word admitted to the vocabulary.
    max_word_length: int = 12
    # Casing applied to each selected word before sorting.
    transformation: TextTransformation = TextTransformation.ORIGINAL
    # Base seed for item generation.
    seed: Optional[int] = None
    # Number of items the (virtual) dataset reports.
    size: int = 500

    def validate(self) -> None:
        """Check the count bounds, length bounds and transformation type.

        Raises:
            AssertionError: on the first violated constraint, checked in the
                order: word-count bounds, word-length bounds, transformation
                type.
        """
        # Word-count bounds.
        assert self.min_words > 0, "min_words must be positive"
        assert self.min_words <= self.max_words, "max_words must be >= min_words"

        # Word-length bounds.
        assert self.min_word_length > 0, "min_word_length must be positive"
        assert self.min_word_length <= self.max_word_length, "max_word_length must be >= min_word_length"

        # Plain strings are rejected; an enum member is required.
        is_enum_member = isinstance(self.transformation, TextTransformation)
        assert is_enum_member, "transformation must be a TextTransformation"
|
||||
|
||||
|
||||
class WordSortingDataset(ProceduralDataset):
|
||||
"""Generates word sorting tasks"""
|
||||
|
||||
def __init__(self, config: WordSortingConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
# Load and preprocess text
|
||||
text = read_data_file("in_the_year_2889.txt")
|
||||
# Extract unique words within length constraints
|
||||
self.words = sorted(
|
||||
set(
|
||||
word
|
||||
for word in re.findall(r"\b\w+\b", text)
|
||||
if self.config.min_word_length <= len(word) <= self.config.max_word_length
|
||||
)
|
||||
)
|
||||
|
||||
def _transform_word(self, word: str, rng: Random) -> str:
|
||||
"""Apply configured transformation to word"""
|
||||
if self.config.transformation == TextTransformation.LOWERCASE:
|
||||
return word.lower()
|
||||
elif self.config.transformation == TextTransformation.UPPERCASE:
|
||||
return word.upper()
|
||||
elif self.config.transformation == TextTransformation.RANDOMCASE:
|
||||
return "".join(c.upper() if rng.choice([True, False]) else c.lower() for c in word)
|
||||
return word # ORIGINAL case
|
||||
|
||||
def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]:
|
||||
"""Generate list of words and their transformed versions"""
|
||||
count = rng.randint(self.config.min_words, self.config.max_words)
|
||||
|
||||
# Select random words
|
||||
selected_words = rng.sample(self.words, count)
|
||||
# Apply transformation
|
||||
transformed_words = [self._transform_word(word, rng) for word in selected_words]
|
||||
|
||||
return selected_words, transformed_words
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single sorting task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
original_words, transformed_words = self._generate_words(rng)
|
||||
|
||||
# Generate both ascending and descending answers
|
||||
asc_words = sorted(transformed_words)
|
||||
desc_words = sorted(transformed_words, reverse=True)
|
||||
|
||||
# Randomly choose ascending or descending
|
||||
is_ascending = rng.choice([True, False])
|
||||
direction = "ascending" if is_ascending else "descending"
|
||||
answer = asc_words if is_ascending else desc_words
|
||||
|
||||
return {
|
||||
"question": f"Sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list:\n{', '.join(transformed_words)}",
|
||||
"answer": ", ".join(answer),
|
||||
"metadata": {
|
||||
"original_words": original_words,
|
||||
"transformed_words": transformed_words,
|
||||
"direction": direction,
|
||||
"transformation": self.config.transformation,
|
||||
"sorted_words": answer,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Register the dataset class and its config class with the factory under the name "word_sorting".
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)
|
||||
Loading…
Add table
Add a link
Reference in a new issue