diff --git a/README.md b/README.md index 60952e1f..db866593 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ Available dataset names (which can be used with `create_dataset()`): - `LetterCountingDataset`: Count letter occurrences in text spans - `NumberFilteringDataset`: Filter numbers based on comparison with threshold - `NumberSortingDataset`: Sort lists of numbers in ascending or descending order -- `LetterJumbleDataset`: Unscramble words that have had their letters randomly jumbled +- `LetterJumbleDataset`: Unscramble words that have had their letters randomly jumbled - `SpellBackwardDataset`: Spell individual words backward (e.g. "sun" -> "nus") - `WordSequenceReversalDataset`: Reverse word order in text spans diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index 83dc0dc0..59b163ee 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -9,12 +9,12 @@ from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset -@dataclass +@dataclass class SpellBackwardConfig: """Configuration for spelling words backward task generation""" - + min_word_len: int = 3 # Minimum word length - seed: Optional[int] = None + seed: Optional[int] = None size: int = 500 # Virtual dataset size def validate(self) -> None: @@ -31,8 +31,9 @@ class SpellBackwardDataset(ProceduralDataset): # Load and preprocess text text = read_data_file("in_the_year_2889.txt") # Extract words and clean them to contain only alphanumeric characters - self.words = [word for word in re.findall(r"\b\w+\b", text) - if word.isalnum() and len(word) >= config.min_word_len] + self.words = [ + word for word in re.findall(r"\b\w+\b", text) if word.isalnum() and len(word) >= config.min_word_len + ] def __getitem__(self, idx: int) -> dict: """Generate a single spell backward task""" diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_sequence_reversal.py similarity index 100% rename from reasoning_gym/algorithmic/word_reversal.py rename to reasoning_gym/algorithmic/word_sequence_reversal.py diff --git a/tests/test_spell_backward.py b/tests/test_spell_backward.py new file mode 100644 index 00000000..2db86c62 --- /dev/null +++ b/tests/test_spell_backward.py @@ -0,0 +1,59 @@ +"""Tests for spell backward task generation""" + +import pytest + +from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset + + +def test_spell_backward_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = SpellBackwardConfig(min_word_len=0) + config.validate() + + +def test_spell_backward_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = SpellBackwardConfig(seed=42, size=10) + dataset1 = SpellBackwardDataset(config) + dataset2 = SpellBackwardDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_spell_backward_dataset_items(): + """Test basic properties of generated items""" + config = SpellBackwardConfig(min_word_len=3, size=10, seed=42) + dataset = SpellBackwardDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "word" in item["metadata"] + assert "word_len" in item["metadata"] + + # Verify word length constraint + word = item["metadata"]["word"] + assert len(word) >= config.min_word_len + + # Verify answer is correct + assert item["answer"] == word[::-1] + + +def test_spell_backward_dataset_iteration(): + """Test that iteration respects dataset size""" + config = SpellBackwardConfig(size=5, seed=42) + dataset = SpellBackwardDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) diff --git a/tests/test_word_reversal.py b/tests/test_word_sequence_reversal.py similarity index 57% rename from tests/test_word_reversal.py rename to tests/test_word_sequence_reversal.py index 15680223..a117bfba 100644 --- a/tests/test_word_reversal.py +++ b/tests/test_word_sequence_reversal.py @@ -1,9 +1,6 @@ -"""Tests for word reversal task generation""" - import pytest -from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset -from reasoning_gym.algorithmic.word_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset +from reasoning_gym.algorithmic.word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset def test_word_sequence_reversal_config_validation(): @@ -55,60 +52,6 @@ def test_word_sequence_reversal_dataset_items(): assert answer_words == list(reversed(question_words)) -def test_spell_backward_config_validation(): - """Test that invalid configs raise appropriate errors""" - with pytest.raises(AssertionError): - config = SpellBackwardConfig(min_word_len=0) - config.validate() - - -def test_spell_backward_dataset_deterministic(): - """Test that dataset generates same items with same seed""" - config = SpellBackwardConfig(seed=42, size=10) - dataset1 = SpellBackwardDataset(config) - dataset2 = SpellBackwardDataset(config) - - for i in range(len(dataset1)): - assert dataset1[i] == dataset2[i] - - -def test_spell_backward_dataset_items(): - """Test basic properties of generated items""" - config = SpellBackwardConfig(min_word_len=3, size=10, seed=42) - dataset = SpellBackwardDataset(config) - - for i in range(len(dataset)): - item = dataset[i] - # Check item structure - assert isinstance(item, dict) - assert "question" in item - assert "answer" in item - assert "metadata" in item - - # Check metadata - assert "word" in item["metadata"] - assert "word_len" in item["metadata"] - - # Verify word length constraint - word = item["metadata"]["word"] - assert len(word) >= config.min_word_len - - # Verify answer is correct - assert item["answer"] == word[::-1] - - -def test_spell_backward_dataset_iteration(): - """Test that iteration respects dataset size""" - config = SpellBackwardConfig(size=5, seed=42) - dataset = SpellBackwardDataset(config) - - items = list(dataset) - assert len(items) == config.size - - # Test multiple iterations yield same items - assert items == list(dataset) - - def test_word_sequence_reversal_dataset_iteration(): """Test that iteration respects dataset size""" config = WordSequenceReversalConfig(size=5, seed=42)