diff --git a/README.md b/README.md index 12997eb4..d32fb5de 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ Available dataset names (which can be used with `create_dataset()`): - `LetterCountingDataset`: Count letter occurrences in text spans - `NumberFilteringDataset`: Filter numbers based on comparison with threshold - `NumberSortingDataset`: Sort lists of numbers in ascending or descending order +- `UnscrambleWordsDataset`: Unscramble words that have had their characters randomly swapped - `WordReversalDataset`: Reverse word order in text spans #### Cognition Tasks diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index bf26d4d3..acd4b086 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -11,6 +11,7 @@ from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset +from .unscramble_words import UnscrambleWordsConfig, UnscrambleWordsDataset from .word_reversal import WordReversalConfig, WordReversalDataset __all__ = [ @@ -24,6 +25,8 @@ __all__ = [ "NumberFilteringDataset", "NumberSortingConfig", "NumberSortingDataset", + "UnscrambleWordsConfig", + "UnscrambleWordsDataset", "WordReversalConfig", "WordReversalDataset", ] diff --git a/reasoning_gym/algorithmic/unscramble_words.py b/reasoning_gym/algorithmic/unscramble_words.py new file mode 100644 index 00000000..87febd7f --- /dev/null +++ b/reasoning_gym/algorithmic/unscramble_words.py @@ -0,0 +1,96 @@ +"""Word unscrambling task generator""" + +import re +from dataclasses import dataclass +from random import Random +from typing import List, Optional + +from reasoning_gym.data import read_data_file +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class UnscrambleWordsConfig: + """Configuration for word unscrambling task generation""" + + min_word_len: int = 4 # Minimum word length + max_word_len: int = 64 # Maximum word length + min_words: int = 3 # Minimum words per task + max_words: int = 20 # Maximum words per task + min_corruption_level: float = 0.1 # Minimum fraction of characters to swap + max_corruption_level: float = 0.9 # Maximum fraction of characters to swap + seed: Optional[int] = None + size: int = 500 # Virtual dataset size + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.min_word_len > 0, "min_word_len must be positive" + assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len" + assert self.min_words > 0, "min_words must be positive" + assert self.max_words >= self.min_words, "max_words must be >= min_words" + assert 0 <= self.min_corruption_level <= 1, "min_corruption_level must be in [0,1]" + assert 0 <= self.max_corruption_level <= 1, "max_corruption_level must be in [0,1]" + assert self.max_corruption_level >= self.min_corruption_level, "max_corruption_level must be >= min_corruption_level" + + +class UnscrambleWordsDataset(ProceduralDataset): + """Generates word unscrambling tasks""" + + def __init__(self, config: UnscrambleWordsConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + # Load and preprocess text + text = read_data_file("in_the_year_2889.txt") + # Extract words and filter by length + self.words = [ + word for word in re.findall(r"\b\w+\b", text) + if self.config.min_word_len <= len(word) <= self.config.max_word_len + and word.isalpha() + ] + + def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str: + """Scramble a word by swapping random pairs of characters""" + if len(word) < 2: # Can't scramble 1-character words + return word + + word = list(word) + num_swaps = int(len(word) * corruption_level) + + for _ in range(num_swaps): + # Pick two different random positions + pos1, pos2 = rng.sample(range(len(word)), 2) + # Swap characters + word[pos1], word[pos2] = word[pos2], word[pos1] + + return "".join(word) + + def __getitem__(self, idx: int) -> dict: + """Generate a single word unscrambling task""" + rng = Random(self.seed + idx) + + # Select number of words and corruption level + num_words = rng.randint(self.config.min_words, self.config.max_words) + corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level) + + # Select random words + selected_words = rng.sample(self.words, num_words) + + # Scramble each word + scrambled_words = [ + self._scramble_word(word, corruption_level, rng) + for word in selected_words + ] + + return { + "question": f"Unscramble these words: {' '.join(scrambled_words)}", + "answer": " ".join(selected_words), + "metadata": { + "num_words": num_words, + "corruption_level": corruption_level, + "scrambled_words": scrambled_words, + "original_words": selected_words + } + } + + +register_dataset("unscramble_words", UnscrambleWordsDataset, UnscrambleWordsConfig) diff --git a/tests/test_unscramble_words.py b/tests/test_unscramble_words.py new file mode 100644 index 00000000..a69ae752 --- /dev/null +++ b/tests/test_unscramble_words.py @@ -0,0 +1,119 @@ +"""Tests for word unscrambling task generation""" + +import pytest + +from reasoning_gym.algorithmic.unscramble_words import UnscrambleWordsConfig, UnscrambleWordsDataset + + +def test_unscramble_words_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = UnscrambleWordsConfig(min_word_len=0) + config.validate() + + with pytest.raises(AssertionError): + config = UnscrambleWordsConfig(min_words=10, max_words=5) + config.validate() + + with pytest.raises(AssertionError): + config = UnscrambleWordsConfig(min_corruption_level=-0.1) + config.validate() + + with pytest.raises(AssertionError): + config = UnscrambleWordsConfig(max_corruption_level=1.1) + config.validate() + + +def test_unscramble_words_deterministic(): + """Test that dataset generates same items with same seed""" + config = UnscrambleWordsConfig(seed=42, size=10) + dataset1 = UnscrambleWordsDataset(config) + dataset2 = UnscrambleWordsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_unscramble_words_scrambling(): + """Test the word scrambling logic""" + config = UnscrambleWordsConfig( + min_word_len=4, + max_word_len=8, + min_words=1, + max_words=1, + min_corruption_level=0.5, + max_corruption_level=0.5, + size=1, + seed=42 + ) + dataset = UnscrambleWordsDataset(config) + + # Test with known word + word = "testing" + rng = Random(42) + scrambled = dataset._scramble_word(word, 0.5, rng) + + # Verify scrambled word: + # - Has same length as original + assert len(scrambled) == len(word) + # - Contains same characters + assert sorted(scrambled) == sorted(word) + # - Is different from original (with high probability given 0.5 corruption) + assert scrambled != word + + +def test_unscramble_words_dataset_items(): + """Test basic properties of generated items""" + config = UnscrambleWordsConfig( + min_word_len=4, + max_word_len=8, + min_words=3, + max_words=5, + min_corruption_level=0.1, + max_corruption_level=0.3, + size=50, + seed=42 + ) + dataset = UnscrambleWordsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + metadata = item["metadata"] + assert "num_words" in metadata + assert "corruption_level" in metadata + assert "scrambled_words" in metadata + assert "original_words" in metadata + + # Verify word counts + num_words = metadata["num_words"] + assert config.min_words <= num_words <= config.max_words + assert len(metadata["scrambled_words"]) == num_words + assert len(metadata["original_words"]) == num_words + + # Verify corruption level + assert config.min_corruption_level <= metadata["corruption_level"] <= config.max_corruption_level + + # Verify word properties + for word in metadata["original_words"]: + assert config.min_word_len <= len(word) <= config.max_word_len + assert word.isalpha() + + +def test_unscramble_words_iteration(): + """Test that iteration respects dataset size""" + config = UnscrambleWordsConfig(size=5, seed=42) + dataset = UnscrambleWordsDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset)