remove old files

This commit is contained in:
Andreas Koepf 2025-01-25 18:51:07 +01:00
parent 7c61924335
commit 5fc0b1bdc3
2 changed files with 0 additions and 223 deletions

View file

@ -1,103 +0,0 @@
"""Word unscrambling task generator"""
import re
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@dataclass
class UnscrambleWordsConfig:
"""Configuration for word unscrambling task generation"""
min_word_len: int = 4 # Minimum word length
max_word_len: int = 64 # Maximum word length
min_words: int = 3 # Minimum words per task
max_words: int = 20 # Maximum words per task
min_corruption_level: float = 0.1 # Minimum fraction of characters to swap
max_corruption_level: float = 0.9 # Maximum fraction of characters to swap
consecutive_words: bool = True # Whether to select consecutive words from text
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_word_len > 0, "min_word_len must be positive"
assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"
assert self.min_words > 0, "min_words must be positive"
assert self.max_words >= self.min_words, "max_words must be >= min_words"
assert 0 <= self.min_corruption_level <= 1, "min_corruption_level must be in [0,1]"
assert 0 <= self.max_corruption_level <= 1, "max_corruption_level must be in [0,1]"
assert self.max_corruption_level >= self.min_corruption_level, "max_corruption_level must be >= min_corruption_level"
class UnscrambleWordsDataset(ProceduralDataset):
"""Generates word unscrambling tasks"""
def __init__(self, config: UnscrambleWordsConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")
# Extract words and filter by length
self.words = [
word for word in re.findall(r"\b\w+\b", text)
if self.config.min_word_len <= len(word) <= self.config.max_word_len
and word.isalpha()
]
def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str:
"""Scramble a word by swapping random pairs of characters"""
if len(word) < 2: # Can't scramble 1-character words
return word
word = list(word)
num_swaps = max(1, int(len(word) * corruption_level)) # Ensure at least one swap
for _ in range(num_swaps):
# Pick two different random positions
pos1, pos2 = rng.sample(range(len(word)), 2)
# Swap characters
word[pos1], word[pos2] = word[pos2], word[pos1]
return "".join(word)
def __getitem__(self, idx: int) -> dict:
"""Generate a single word unscrambling task"""
rng = Random(self.seed + idx)
# Select number of words and corruption level
num_words = rng.randint(self.config.min_words, self.config.max_words)
corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level)
# Select words based on configuration
if self.config.consecutive_words:
# Select consecutive words from a random starting position
start_idx = rng.randint(0, len(self.words) - num_words)
selected_words = self.words[start_idx:start_idx + num_words]
else:
# Select random words
selected_words = rng.sample(self.words, num_words)
# Scramble each word
scrambled_words = [
self._scramble_word(word, corruption_level, rng)
for word in selected_words
]
return {
"question": f"Unscramble these words: {' '.join(scrambled_words)}",
"answer": " ".join(selected_words),
"metadata": {
"num_words": num_words,
"corruption_level": corruption_level,
"scrambled_words": scrambled_words,
"original_words": selected_words
}
}
register_dataset("unscramble_words", UnscrambleWordsDataset, UnscrambleWordsConfig)

View file

@ -1,120 +0,0 @@
"""Tests for word unscrambling task generation"""
import pytest
from random import Random
from reasoning_gym.algorithmic.unscramble_words import UnscrambleWordsConfig, UnscrambleWordsDataset
def test_unscramble_words_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = UnscrambleWordsConfig(min_word_len=0)
config.validate()
with pytest.raises(AssertionError):
config = UnscrambleWordsConfig(min_words=10, max_words=5)
config.validate()
with pytest.raises(AssertionError):
config = UnscrambleWordsConfig(min_corruption_level=-0.1)
config.validate()
with pytest.raises(AssertionError):
config = UnscrambleWordsConfig(max_corruption_level=1.1)
config.validate()
def test_unscramble_words_deterministic():
"""Test that dataset generates same items with same seed"""
config = UnscrambleWordsConfig(seed=42, size=10)
dataset1 = UnscrambleWordsDataset(config)
dataset2 = UnscrambleWordsDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_unscramble_words_scrambling():
"""Test the word scrambling logic"""
config = UnscrambleWordsConfig(
min_word_len=4,
max_word_len=8,
min_words=1,
max_words=1,
min_corruption_level=0.5,
max_corruption_level=0.5,
size=1,
seed=42
)
dataset = UnscrambleWordsDataset(config)
# Test with known word
word = "testing"
rng = Random(42)
scrambled = dataset._scramble_word(word, 0.5, rng)
# Verify scrambled word:
# - Has same length as original
assert len(scrambled) == len(word)
# - Contains same characters
assert sorted(scrambled) == sorted(word)
# - Is different from original (with high probability given 0.5 corruption)
assert scrambled != word
def test_unscramble_words_dataset_items():
"""Test basic properties of generated items"""
config = UnscrambleWordsConfig(
min_word_len=4,
max_word_len=8,
min_words=3,
max_words=5,
min_corruption_level=0.1,
max_corruption_level=0.3,
size=50,
seed=42
)
dataset = UnscrambleWordsDataset(config)
for i in range(len(dataset)):
item = dataset[i]
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
metadata = item["metadata"]
assert "num_words" in metadata
assert "corruption_level" in metadata
assert "scrambled_words" in metadata
assert "original_words" in metadata
# Verify word counts
num_words = metadata["num_words"]
assert config.min_words <= num_words <= config.max_words
assert len(metadata["scrambled_words"]) == num_words
assert len(metadata["original_words"]) == num_words
# Verify corruption level
assert config.min_corruption_level <= metadata["corruption_level"] <= config.max_corruption_level
# Verify word properties
for word in metadata["original_words"]:
assert config.min_word_len <= len(word) <= config.max_word_len
assert word.isalpha()
def test_unscramble_words_iteration():
"""Test that iteration respects dataset size"""
config = UnscrambleWordsConfig(size=5, seed=42)
dataset = UnscrambleWordsDataset(config)
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same items
assert items == list(dataset)