diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 6185d73e..6fe3cc9a 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -24,7 +24,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJum from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset -from .palindrome_generation import PalindromeConfig, PalindromeDataset +from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset from .palindrome_partitioning import ( PalindromePartitioningConfig, PalindromePartitioningCurriculum, @@ -88,6 +88,7 @@ __all__ = [ "WordLadderDataset", "PalindromeConfig", "PalindromeDataset", + "PalindromeCurriculum", "GroupAnagramsConfig", "GroupAnagramsDataset", "GroupAnagramsCurriculum", diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py index acbb5a17..dd130036 100644 --- a/reasoning_gym/algorithmic/palindrome_generation.py +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -3,6 +3,7 @@ import string from dataclasses import dataclass from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPALTE = """Your task is, given a list of letters, to form a valid palindrome. @@ -68,6 +69,9 @@ class PalindromeDataset(ProceduralDataset): "metadata": { "letters": scrambled_letters, "generated_palindrome": palindrome, + "difficulty": { + "length": length, + }, }, } @@ -116,4 +120,23 @@ class PalindromeDataset(ProceduralDataset): return 1.0 # Correct solution -register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig) +class PalindromeCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(PalindromeCurriculum.__name__, PalindromeConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="length", + levels=[10, 50, 100, 500], + default_level=1, + description="Length of the generated palindrome.", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_length", + upper_field_name="max_length", + ) + ) + + +register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig, PalindromeCurriculum) diff --git a/tests/test_palindrome.py b/tests/test_palindrome.py index 472390de..48171d9a 100644 --- a/tests/test_palindrome.py +++ b/tests/test_palindrome.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset +from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset def test_palindrome_config_validation(): @@ -89,3 +89,24 @@ def test_score_answer(): # Empty input should score 0.0 assert dataset.score_answer(None, entry=item) == 0.0 + + +def test_palindrome_curriculum(): + curriculum = PalindromeCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: PalindromeConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_length == 10 and base_cfg.max_length == 50 + + # test incrementing attribute levels + curriculum.increment_attr_level("length") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_length == 10 and increased_cfg.max_length == 100 + + # test decrementing attribute levels + curriculum.decrement_attr_level("length") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_length == 10 and partially_decreased_cfg.max_length == 50