diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 32dc39be..62450f12 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -20,7 +20,7 @@ from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupA from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset from .jugs import JugsConfig, JugsDataset from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset -from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset +from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset @@ -69,6 +69,7 @@ __all__ = [ "LetterCountingCurriculum", "LetterJumbleConfig", "LetterJumbleDataset", + "LetterJumbleCurriculum", "NumberFilteringConfig", "NumberFilteringDataset", "NumberSortingConfig", diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 2cb3fc08..61c49174 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -7,6 +7,7 @@ from typing import Any, Optional from reasoning_gym.data import read_data_file +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Your task is to unsramble words in a sentence. @@ -107,6 +108,11 @@ class LetterJumbleDataset(ProceduralDataset): "corruption_level": corruption_level, "scrambled_words": scrambled_words, "original_words": selected_words, + "difficulty": { + "word_len": (self.config.min_word_len, self.config.max_word_len), + "words": num_words, + "corruption_level": corruption_level, + }, }, } @@ -154,4 +160,43 @@ class LetterJumbleDataset(ProceduralDataset): return partial_score -register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig) +class LetterJumbleCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(LetterJumbleCurriculum.__name__, LetterJumbleConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="word_len", + levels=[5, 15, 30, 50], + default_level=1, + description="Word length", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_word_len", + upper_field_name="max_word_len", + ), + RangeAttributeDefinition( + name="words", + levels=[10, 50, 100, 500], + default_level=1, + description="Number of words", + attr_type=AttributeType.APPEND, + min_value=5, + lower_field_name="min_words", + upper_field_name="max_words", + ), + RangeAttributeDefinition( + name="corruption_level", + levels=[0.1, 0.3, 0.6, 0.9], + default_level=1, + description="Corruption level", + attr_type=AttributeType.APPEND, + min_value=0.0, + lower_field_name="min_corruption_level", + upper_field_name="max_corruption_level", + ), + ) + + +register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig, LetterJumbleCurriculum) diff --git a/tests/test_letter_jumble.py b/tests/test_letter_jumble.py index 0a11ce1e..a167bcb4 100644 --- a/tests/test_letter_jumble.py +++ b/tests/test_letter_jumble.py @@ -4,7 +4,7 @@ from random import Random import pytest -from reasoning_gym.algorithmic.letter_jumble import LetterJumbleConfig, LetterJumbleDataset +from reasoning_gym.algorithmic.letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset def test_letter_jumble_config_validation(): @@ -128,3 +128,32 @@ def test_letter_jumble_iteration(): # Test multiple iterations yield same items assert items == list(dataset) + + +def test_letter_jumble_curriculum(): + curriculum = LetterJumbleCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: LetterJumbleConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_word_len == 5 and base_cfg.max_word_len == 15 + assert base_cfg.min_words == 10 and base_cfg.max_words == 50 + assert base_cfg.min_corruption_level == 0.1 and base_cfg.max_corruption_level == 0.3 + + # test incrementing attribute levels + curriculum.increment_attr_level("word_len") + curriculum.increment_attr_level("words") + curriculum.increment_attr_level("corruption_level") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_word_len == 5 and increased_cfg.max_word_len == 30 + assert increased_cfg.min_words == 10 and increased_cfg.max_words == 100 + assert increased_cfg.min_corruption_level == 0.1 and increased_cfg.max_corruption_level == 0.6 + + # test decrementing attribute level for words again + curriculum.decrement_attr_level("words") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_word_len == 5 and partially_decreased_cfg.max_word_len == 30 + assert partially_decreased_cfg.min_words == 10 and partially_decreased_cfg.max_words == 50 + assert partially_decreased_cfg.min_corruption_level == 0.1 and partially_decreased_cfg.max_corruption_level == 0.6