diff --git a/README.md b/README.md index dae6975a..1426cac9 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ Available dataset names (which can be used with `create_dataset()`): - `NumberFilteringDataset`: Filter numbers based on comparison with threshold - `NumberSortingDataset`: Sort lists of numbers in ascending or descending order - `LetterJumbleDataset`: Unscramble words that have had their letters randomly jumbled +- `SentenceReorderingDataset`: Reorder sentence after words it in have been randomly shuffled - `WordReversalDataset`: Reverse word order in text spans #### Cognition Tasks diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 78136d66..5eb551f0 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -13,6 +13,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset from .word_reversal import WordReversalConfig, WordReversalDataset +from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset __all__ = [ "BaseConversionConfig", @@ -29,4 +30,6 @@ __all__ = [ "NumberSortingDataset", "WordReversalConfig", "WordReversalDataset", + "SentenceReorderingConfig", + "SentenceReorderingDataset", ] diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py new file mode 100644 index 00000000..36043c8e --- /dev/null +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -0,0 +1,79 @@ +"""Sentence re-ordering task generator""" + +import re +from dataclasses import dataclass +from random import Random +from typing import List, Optional + +from ..data import read_data_file +from ..factory import ProceduralDataset, register_dataset + +@dataclass +class SentenceReorderingConfig: + """Configuration for sentence reordering task generation""" + num_of_words_in_sentence: int = 10 + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + pass + + +class SentenceReorderingDataset(ProceduralDataset): + """Generates sentence reordering tasks from text spans""" + + def __init__(self, config: SentenceReorderingConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + # Load and preprocess text + text = read_data_file("in_the_year_2889.txt") + # Extract sentences make sure they are greater than or equal to the number of words in a sentence + self.sentences = [ + sentence + for sentence in re.findall(r"[^.!?]+", text) + if len(sentence.split()) >= self.config.num_of_words_in_sentence + ] + + def _generate_sentence_dataset(self, sentence: str, seed: int, idx: int, shuffle=True): + """ + Generate a procedural dataset by shuffling the words in the input sentence. + + Args: + sentence (str): The correct sentence to use for dataset generation. + shuffle (bool): Whether to shuffle the words to create the input sentence. + + Returns: + dict: A dictionary containing the input sentence and the correct sentence (goal). + """ + rng = Random(seed + idx) + words = sentence.split() # Split the sentence into words + scrambled_words = words.copy() + if shuffle: + rng.shuffle(scrambled_words) # Shuffle the words to generate the input + input_sentence = " ".join(scrambled_words) + goal_sentence = " ".join(words) + return {"input": input_sentence, "goal": goal_sentence} + + def __getitem__(self, idx: int) -> dict: + """Generate a single sentence reordering task""" + rng = Random(self.seed + idx) + sentence_dataset = self._generate_sentence_dataset(rng.choice(self.sentences), self.seed, idx) + + # Ensure only 'input' and 'goal' keys are present + if set(sentence_dataset.keys()) != {'input', 'goal'}: + raise KeyError("The dictionary must contain only 'input' and 'goal' keys") + + # Solve the task by sorting words to match the goal sentence + input_words = sentence_dataset['input'].split() + question = " ".join(input_words) + goal_words = sentence_dataset['goal'].split() + solved_sentence = " ".join(sorted(input_words, key=lambda word: goal_words.index(word))) + + return { + "question": f"Correct the following sentence: {question}", + "answer": solved_sentence, + "metadata": {"num_of_words_in_sentence": len(goal_words)}, + } + +register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py new file mode 100644 index 00000000..8ebcc3d1 --- /dev/null +++ b/tests/test_sentence_reordering.py @@ -0,0 +1,45 @@ +import pytest +from reasoning_gym.algorithmic.sentence_reordering import ( + SentenceReorderingConfig, + SentenceReorderingDataset, +) + +@pytest.fixture +def config(): + return SentenceReorderingConfig(num_of_words_in_sentence=5, seed=42, size=10) + +@pytest.fixture +def dataset(config): + return SentenceReorderingDataset(config=config) + +def test_config_validation(config): + # Test that the config validation does not raise any exceptions + try: + config.validate() + except Exception as e: + pytest.fail(f"Config validation raised an exception: {e}") + +def test_generate_sentence_dataset(dataset): + sentence = "This is a test sentence for reordering" + result = dataset._generate_sentence_dataset(sentence, seed=42, idx=0, shuffle=True) + assert "input" in result + assert "goal" in result + assert result["input"] != result["goal"] + assert sorted(result["input"].split()) == sorted(result["goal"].split()) + +def test_getitem(dataset, config): + item = dataset[0] + assert "question" in item + assert "answer" in item + assert "metadata" in item + assert item["metadata"]["num_of_words_in_sentence"] >= config.num_of_words_in_sentence + +def test_key_error_in_getitem(dataset): + # Modify the dataset to include an incorrect key + def mock_generate_sentence_dataset(*args, **kwargs): + return {"input": "mock input", "goal": "mock goal", "extra": "extra key"} + + dataset._generate_sentence_dataset = mock_generate_sentence_dataset + + with pytest.raises(KeyError): + dataset[0] \ No newline at end of file