diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py new file mode 100644 index 00000000..1218f447 --- /dev/null +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -0,0 +1,104 @@ +"""Word sorting task generator""" + +from dataclasses import dataclass +from enum import Enum +from random import Random +from typing import List, Optional, Tuple + +from ..data import read_data_file +from ..factory import ProceduralDataset, register_dataset + + +class TextTransformation(str, Enum): + """Text transformation options""" + LOWERCASE = "lowercase" + UPPERCASE = "uppercase" + ORIGINAL = "original" + RANDOMCASE = "randomcase" + + +@dataclass +class WordSortingConfig: + """Configuration for word sorting task generation""" + min_words: int = 3 # Minimum words to sort + max_words: int = 10 # Maximum words to sort + min_word_length: int = 3 # Minimum word length + max_word_length: int = 12 # Maximum word length + transformation: TextTransformation = TextTransformation.ORIGINAL + seed: Optional[int] = None + size: int = 500 # Virtual dataset size + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.min_words > 0, "min_words must be positive" + assert self.min_words <= self.max_words, "max_words must be >= min_words" + assert self.min_word_length > 0, "min_word_length must be positive" + assert self.min_word_length <= self.max_word_length, "max_word_length must be >= min_word_length" + assert isinstance(self.transformation, TextTransformation), "transformation must be a TextTransformation" + + +class WordSortingDataset(ProceduralDataset): + """Generates word sorting tasks""" + + def __init__(self, config: WordSortingConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + # Load and preprocess text + text = read_data_file("in_the_year_2889.txt") + # Extract unique words within length constraints + self.words = list(set( + word for word in re.findall(r'\b\w+\b', text) + if self.config.min_word_length <= len(word) <= self.config.max_word_length + )) + + def _transform_word(self, word: str, rng: Random) -> str: + """Apply configured transformation to word""" + if self.config.transformation == TextTransformation.LOWERCASE: + return word.lower() + elif self.config.transformation == TextTransformation.UPPERCASE: + return word.upper() + elif self.config.transformation == TextTransformation.RANDOMCASE: + return ''.join(c.upper() if rng.choice([True, False]) else c.lower() + for c in word) + return word # ORIGINAL case + + def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]: + """Generate list of words and their transformed versions""" + count = rng.randint(self.config.min_words, self.config.max_words) + + # Select random words + selected_words = rng.sample(self.words, count) + # Apply transformation + transformed_words = [self._transform_word(word, rng) for word in selected_words] + + return selected_words, transformed_words + + def __getitem__(self, idx: int) -> dict: + """Generate a single sorting task""" + rng = Random(self.seed + idx) + + original_words, transformed_words = self._generate_words(rng) + + # Generate both ascending and descending answers + asc_words = sorted(transformed_words) + desc_words = sorted(transformed_words, reverse=True) + + # Randomly choose ascending or descending + is_ascending = rng.choice([True, False]) + direction = "ascending" if is_ascending else "descending" + answer = asc_words if is_ascending else desc_words + + return { + "question": f"Sort these words in {direction} order: {', '.join(transformed_words)}", + "answer": str(answer), + "metadata": { + "original_words": original_words, + "transformed_words": transformed_words, + "direction": direction, + "transformation": self.config.transformation, + "sorted_words": answer + }, + } + + +register_dataset("word_sorting", WordSortingDataset, WordSortingConfig) diff --git a/tests/test_word_sorting.py b/tests/test_word_sorting.py new file mode 100644 index 00000000..386e5d12 --- /dev/null +++ b/tests/test_word_sorting.py @@ -0,0 +1,118 @@ +"""Tests for word sorting task generation""" + +import pytest + +from reasoning_gym.algorithmic.word_sorting import WordSortingConfig, WordSortingDataset, TextTransformation + + +def test_word_sorting_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = WordSortingConfig(min_words=0) + config.validate() + + with pytest.raises(AssertionError): + config = WordSortingConfig(min_words=10, max_words=5) + config.validate() + + with pytest.raises(AssertionError): + config = WordSortingConfig(min_word_length=0) + config.validate() + + with pytest.raises(AssertionError): + config = WordSortingConfig(min_word_length=10, max_word_length=5) + config.validate() + + +def test_word_sorting_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = WordSortingConfig(seed=42, size=10) + dataset1 = WordSortingDataset(config) + dataset2 = WordSortingDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_word_sorting_transformations(): + """Test different text transformations""" + seed = 42 + size = 5 + + # Test LOWERCASE + config = WordSortingConfig(transformation=TextTransformation.LOWERCASE, seed=seed, size=size) + dataset = WordSortingDataset(config) + for item in dataset: + for word in item["metadata"]["transformed_words"]: + assert word.islower() + + # Test UPPERCASE + config = WordSortingConfig(transformation=TextTransformation.UPPERCASE, seed=seed, size=size) + dataset = WordSortingDataset(config) + for item in dataset: + for word in item["metadata"]["transformed_words"]: + assert word.isupper() + + # Test ORIGINAL + config = WordSortingConfig(transformation=TextTransformation.ORIGINAL, seed=seed, size=size) + dataset = WordSortingDataset(config) + for item in dataset: + assert item["metadata"]["original_words"] == item["metadata"]["transformed_words"] + + +def test_word_sorting_dataset_items(): + """Test basic properties of generated items""" + config = WordSortingConfig( + min_words=3, + max_words=6, + min_word_length=3, + max_word_length=8, + size=10, + seed=42 + ) + dataset = WordSortingDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "original_words" in item["metadata"] + assert "transformed_words" in item["metadata"] + assert "direction" in item["metadata"] + assert "transformation" in item["metadata"] + assert "sorted_words" in item["metadata"] + + # Verify word count constraints + words = item["metadata"]["transformed_words"] + assert len(words) >= config.min_words + assert len(words) <= config.max_words + + # Verify word length constraints + for word in words: + assert len(word) >= config.min_word_length + assert len(word) <= config.max_word_length + + # Verify sorting + direction = item["metadata"]["direction"] + sorted_words = eval(item["answer"]) + if direction == "ascending": + assert sorted_words == sorted(sorted_words) + else: + assert sorted_words == sorted(sorted_words, reverse=True) + + +def test_word_sorting_dataset_iteration(): + """Test that iteration respects dataset size""" + config = WordSortingConfig(size=5, seed=42) + dataset = WordSortingDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset)