diff --git a/README.md b/README.md index 06b71bfb..60952e1f 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ Available dataset names (which can be used with `create_dataset()`): - `NumberSortingDataset`: Sort lists of numbers in ascending or descending order - `LetterJumbleDataset`: Unscramble words that have had their letters randomly jumbled - `SpellBackwardDataset`: Spell individual words backward (e.g. "sun" -> "nus") -- `WordReversalDataset`: Reverse word order in text spans +- `WordSequenceReversalDataset`: Reverse word order in text spans #### Cognition Tasks diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index b9810fdc..7026cc7d 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -13,7 +13,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset -from .word_reversal import WordReversalConfig, WordReversalDataset +from .word_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset __all__ = [ "SpellBackwardConfig", @@ -30,6 +30,6 @@ __all__ = [ "NumberFilteringDataset", "NumberSortingConfig", "NumberSortingDataset", - "WordReversalConfig", - "WordReversalDataset", + "WordSequenceReversalConfig", + "WordSequenceReversalDataset", ] diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_reversal.py index b08b459d..ce9f273d 100644 --- a/reasoning_gym/algorithmic/word_reversal.py +++ b/reasoning_gym/algorithmic/word_reversal.py @@ -10,8 +10,8 @@ from ..factory import ProceduralDataset, register_dataset @dataclass -class WordReversalConfig: - """Configuration for word reversal task generation""" +class WordSequenceReversalConfig: + """Configuration for word sequence reversal task generation""" min_words: int = 3 # Minimum words in list max_words: int = 8 # Maximum words in list @@ -24,8 +24,8 @@ class WordReversalConfig: assert self.max_words >= self.min_words, "max_words must be >= min_words" -class WordReversalDataset(ProceduralDataset): - """Generates word reversal tasks from text spans""" +class WordSequenceReversalDataset(ProceduralDataset): + """Generates word sequence reversal tasks from text spans""" def __init__(self, config: WordReversalConfig): super().__init__(config=config, seed=config.seed, size=config.size) @@ -55,4 +55,4 @@ class WordReversalDataset(ProceduralDataset): } -register_dataset("word_reversal", WordReversalDataset, WordReversalConfig) +register_dataset("word_sequence_reversal", WordSequenceReversalDataset, WordSequenceReversalConfig) diff --git a/tests/test_word_reversal.py b/tests/test_word_reversal.py index 310f9cc1..bcf042ce 100644 --- a/tests/test_word_reversal.py +++ b/tests/test_word_reversal.py @@ -3,13 +3,13 @@ import pytest from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset -from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset +from reasoning_gym.algorithmic.word_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset -def test_word_reversal_config_validation(): +def test_word_sequence_reversal_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = WordReversalConfig(min_words=0) + config = WordSequenceReversalConfig(min_words=0) config.validate() with pytest.raises(AssertionError): @@ -17,20 +17,20 @@ def test_word_reversal_config_validation(): config.validate() -def test_word_reversal_dataset_deterministic(): +def test_word_sequence_reversal_dataset_deterministic(): """Test that dataset generates same items with same seed""" - config = WordReversalConfig(seed=42, size=10) - dataset1 = WordReversalDataset(config) - dataset2 = WordReversalDataset(config) + config = WordSequenceReversalConfig(seed=42, size=10) + dataset1 = WordSequenceReversalDataset(config) + dataset2 = WordSequenceReversalDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] -def test_word_reversal_dataset_items(): +def test_word_sequence_reversal_dataset_items(): """Test basic properties of generated items""" - config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42) - dataset = WordReversalDataset(config) + config = WordSequenceReversalConfig(min_words=3, max_words=6, size=10, seed=42) + dataset = WordSequenceReversalDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -109,10 +109,10 @@ def test_spell_backward_dataset_iteration(): assert items == list(dataset) -def test_word_reversal_dataset_iteration(): +def test_word_sequence_reversal_dataset_iteration(): """Test that iteration respects dataset size""" - config = WordReversalConfig(size=5, seed=42) - dataset = WordReversalDataset(config) + config = WordSequenceReversalConfig(size=5, seed=42) + dataset = WordSequenceReversalDataset(config) items = list(dataset) assert len(items) == config.size @@ -121,10 +121,10 @@ def test_word_reversal_dataset_iteration(): assert items == list(dataset) -def test_word_reversal_text_preprocessing(): +def test_word_sequence_reversal_text_preprocessing(): """Test that text preprocessing handles edge cases""" - config = WordReversalConfig(size=1, seed=42) - dataset = WordReversalDataset(config) + config = WordSequenceReversalConfig(size=1, seed=42) + dataset = WordSequenceReversalDataset(config) # Verify words were extracted from text assert len(dataset.words) > 0