# Recovered from a whitespace-collapsed "deleted file" diff of
# reasoning_gym/algorithmic/unscramble_words.py.
"""Word unscrambling task generator"""

import re
from dataclasses import dataclass
from random import Random
from typing import List, Optional

from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset


@dataclass
class UnscrambleWordsConfig:
    """Configuration for word unscrambling task generation"""

    min_word_len: int = 4  # Minimum word length
    max_word_len: int = 64  # Maximum word length
    min_words: int = 3  # Minimum words per task
    max_words: int = 20  # Maximum words per task
    min_corruption_level: float = 0.1  # Minimum fraction of characters to swap
    max_corruption_level: float = 0.9  # Maximum fraction of characters to swap
    consecutive_words: bool = True  # Whether to select consecutive words from text
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if any length, count or corruption bound is out
                of range or if a maximum is smaller than its minimum.
        """
        # NOTE: these are plain asserts, so they are skipped under `python -O`.
        # They are kept as asserts because callers (and the test-suite) expect
        # AssertionError.
        assert self.min_word_len > 0, "min_word_len must be positive"
        assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"
        assert self.min_words > 0, "min_words must be positive"
        assert self.max_words >= self.min_words, "max_words must be >= min_words"
        assert 0 <= self.min_corruption_level <= 1, "min_corruption_level must be in [0,1]"
        assert 0 <= self.max_corruption_level <= 1, "max_corruption_level must be in [0,1]"
        assert (
            self.max_corruption_level >= self.min_corruption_level
        ), "max_corruption_level must be >= min_corruption_level"


class UnscrambleWordsDataset(ProceduralDataset):
    """Generates word unscrambling tasks"""

    def __init__(self, config: UnscrambleWordsConfig):
        """Build the vocabulary used to generate tasks.

        Args:
            config: Task generation settings; its ``seed`` and ``size`` are
                forwarded to ``ProceduralDataset``.

        Raises:
            ValueError: if the word-length filter leaves fewer than
                ``config.min_words`` usable words, in which case no item
                could ever be generated.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Load and preprocess text
        text = read_data_file("in_the_year_2889.txt")
        # Extract words and filter by length; isalpha() drops tokens that
        # contain digits or underscores, which \w also matches.
        self.words: List[str] = [
            word
            for word in re.findall(r"\b\w+\b", text)
            if self.config.min_word_len <= len(word) <= self.config.max_word_len and word.isalpha()
        ]

        # Fail fast with a readable message instead of letting every
        # __getitem__ call die inside rng.randint / rng.sample.
        if len(self.words) < self.config.min_words:
            raise ValueError(
                f"only {len(self.words)} words have a length in "
                f"[{self.config.min_word_len}, {self.config.max_word_len}], "
                f"but min_words={self.config.min_words}"
            )

    def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str:
        """Scramble a word by swapping random pairs of characters.

        Args:
            word: The word to scramble.
            corruption_level: Fraction of the word length used as the number
                of swaps (at least one swap is always performed).
            rng: Random source; it is advanced once per swap.

        Returns:
            A string with the same characters as ``word``. It can equal
            ``word`` when the swaps exchange identical letters or undo each
            other.
        """
        if len(word) < 2:  # Can't scramble 1-character words
            return word

        chars = list(word)
        num_swaps = max(1, int(len(chars) * corruption_level))  # Ensure at least one swap

        for _ in range(num_swaps):
            # Pick two different random positions
            pos1, pos2 = rng.sample(range(len(chars)), 2)
            # Swap characters
            chars[pos1], chars[pos2] = chars[pos2], chars[pos1]

        return "".join(chars)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single word unscrambling task.

        Args:
            idx: Item index; combined with the dataset seed so the same index
                always yields the same task.

        Returns:
            A dict with the keys ``"question"``, ``"answer"`` and
            ``"metadata"`` (``num_words``, ``corruption_level``,
            ``scrambled_words``, ``original_words``).
        """
        # NOTE(review): assumes ProceduralDataset turns a None seed into an
        # int; confirm against the base class.
        rng = Random(self.seed + idx)

        # Select number of words and corruption level
        num_words = rng.randint(self.config.min_words, self.config.max_words)
        corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level)

        # A narrow word-length filter can leave fewer words than were drawn.
        # Clamp *after* the draw so the RNG stream (and therefore every item
        # that could already be generated) stays exactly the same.
        num_words = min(num_words, len(self.words))

        # Select words based on configuration
        if self.config.consecutive_words:
            # Select consecutive words from a random starting position
            start_idx = rng.randint(0, len(self.words) - num_words)
            selected_words = self.words[start_idx : start_idx + num_words]
        else:
            # Select random words
            selected_words = rng.sample(self.words, num_words)

        # Scramble each word
        scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]

        return {
            "question": f"Unscramble these words: {' '.join(scrambled_words)}",
            "answer": " ".join(selected_words),
            "metadata": {
                "num_words": num_words,
                "corruption_level": corruption_level,
                "scrambled_words": scrambled_words,
                "original_words": selected_words,
            },
        }


register_dataset("unscramble_words", UnscrambleWordsDataset, UnscrambleWordsConfig)
# Recovered from a whitespace-collapsed "deleted file" diff of
# tests/test_unscramble_words.py.
"""Tests for word unscrambling task generation"""

import pytest
from random import Random

from reasoning_gym.algorithmic.unscramble_words import UnscrambleWordsConfig, UnscrambleWordsDataset


def test_unscramble_words_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_word_len=0)
        config.validate()

    # max_word_len below min_word_len
    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_word_len=10, max_word_len=5)
        config.validate()

    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_words=0)
        config.validate()

    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_words=10, max_words=5)
        config.validate()

    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_corruption_level=-0.1)
        config.validate()

    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(max_corruption_level=1.1)
        config.validate()

    # Both levels are individually valid, but the range is inverted
    with pytest.raises(AssertionError):
        config = UnscrambleWordsConfig(min_corruption_level=0.8, max_corruption_level=0.2)
        config.validate()


def test_unscramble_words_deterministic():
    """Test that dataset generates same items with same seed"""
    config = UnscrambleWordsConfig(seed=42, size=10)
    dataset1 = UnscrambleWordsDataset(config)
    dataset2 = UnscrambleWordsDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_unscramble_words_scrambling():
    """Test the word scrambling logic"""
    config = UnscrambleWordsConfig(
        min_word_len=4,
        max_word_len=8,
        min_words=1,
        max_words=1,
        min_corruption_level=0.5,
        max_corruption_level=0.5,
        size=1,
        seed=42,
    )
    dataset = UnscrambleWordsDataset(config)

    # Test with known word
    word = "testing"
    rng = Random(42)
    scrambled = dataset._scramble_word(word, 0.5, rng)

    # Verify scrambled word:
    # - Has same length as original
    assert len(scrambled) == len(word)
    # - Contains same characters
    assert sorted(scrambled) == sorted(word)
    # - Is different from original (with high probability given 0.5 corruption;
    #   the fixed seed makes the outcome repeatable)
    assert scrambled != word


def test_unscramble_words_dataset_items():
    """Test basic properties of generated items"""
    config = UnscrambleWordsConfig(
        min_word_len=4,
        max_word_len=8,
        min_words=3,
        max_words=5,
        min_corruption_level=0.1,
        max_corruption_level=0.3,
        size=50,
        seed=42,
    )
    dataset = UnscrambleWordsDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]

        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        metadata = item["metadata"]
        assert "num_words" in metadata
        assert "corruption_level" in metadata
        assert "scrambled_words" in metadata
        assert "original_words" in metadata

        # Verify word counts
        num_words = metadata["num_words"]
        assert config.min_words <= num_words <= config.max_words
        assert len(metadata["scrambled_words"]) == num_words
        assert len(metadata["original_words"]) == num_words

        # Verify corruption level
        assert config.min_corruption_level <= metadata["corruption_level"] <= config.max_corruption_level

        # Verify word properties
        for word in metadata["original_words"]:
            assert config.min_word_len <= len(word) <= config.max_word_len
            assert word.isalpha()

        # Every scrambled word must be a permutation of its original
        for original, scrambled in zip(metadata["original_words"], metadata["scrambled_words"]):
            assert len(scrambled) == len(original)
            assert sorted(scrambled) == sorted(original)

        # Question and answer must agree with the metadata
        assert item["answer"] == " ".join(metadata["original_words"])
        assert item["question"].endswith(" ".join(metadata["scrambled_words"]))


def test_unscramble_words_iteration():
    """Test that iteration respects dataset size"""
    config = UnscrambleWordsConfig(size=5, seed=42)
    dataset = UnscrambleWordsDataset(config)

    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    assert items == list(dataset)