diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 6fe3cc9a..3ebebe6e 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -34,7 +34,7 @@ from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDatas from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset -from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset +from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset @@ -77,6 +77,7 @@ __all__ = [ "NumberSortingCurriculum", "SentenceReorderingConfig", "SentenceReorderingDataset", + "SentenceReorderingCurriculum", "WordSequenceReversalConfig", "WordSequenceReversalDataset", "WordSequenceReversalCurriculum", diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index 0cfbaaaf..3cc697d4 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -89,7 +90,7 @@ class SentenceReorderingDataset(ProceduralDataset): return { "question": f"Restore the correct order of words in the following sentence: {question}", "answer": solved_sentence, - "metadata": {"word_count": word_count}, + "metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}}, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -114,4 +115,25 @@ class SentenceReorderingDataset(ProceduralDataset): return reward -register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) +class SentenceReorderingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(SentenceReorderingCurriculum.__name__, SentenceReorderingConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="words_in_sentence", + levels=[5, 20, 50, 100], + default_level=1, + description="Number of words in the sentence", + attr_type=AttributeType.APPEND, + min_value=3, + lower_field_name="min_words_in_sentence", + upper_field_name="max_words_in_sentence", + ), + ) + + +register_dataset( + "sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum +) diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py index 9348ec04..cf55cc76 100644 --- a/tests/test_sentence_reordering.py +++ b/tests/test_sentence_reordering.py @@ -1,6 +1,10 @@ import pytest -from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset +from reasoning_gym.algorithmic.sentence_reordering import ( + SentenceReorderingConfig, + SentenceReorderingCurriculum, + SentenceReorderingDataset, +) @pytest.fixture @@ -49,3 +53,24 @@ def test_key_error_in_getitem(dataset): with pytest.raises(KeyError): dataset[0] + + +def test_sentence_reordering_curriculum(): + curriculum = SentenceReorderingCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: SentenceReorderingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_words_in_sentence == 5 and base_cfg.max_words_in_sentence == 20 + + # test incrementing attribute levels + curriculum.increment_attr_level("words_in_sentence") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_words_in_sentence == 5 and increased_cfg.max_words_in_sentence == 50 + + # test decrementing attribute level + curriculum.decrement_attr_level("words_in_sentence") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_words_in_sentence == 5 and partially_decreased_cfg.max_words_in_sentence == 20