diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index e6188eb2..0db87606 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -10,7 +10,7 @@ from .ab import ABConfig, ABDataset from .base_conversion import BaseConversionConfig, BaseConversionCurriculum, BaseConversionDataset from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixCurriculum, BinaryMatrixDataset -from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset +from .caesar_cipher import CaesarCipherConfig, CaesarCipherCurriculum, CaesarCipherDataset from .count_primes import CountPrimesConfig, CountPrimesCurriculum, CountPrimesDataset from .cryptarithm import CryptarithmConfig, CryptarithmDataset from .game_of_life import GameOfLifeConfig, GameOfLifeDataset @@ -53,6 +53,7 @@ __all__ = [ "BaseConversionCurriculum", "CaesarCipherConfig", "CaesarCipherDataset", + "CaesarCipherCurriculum", "CryptarithmConfig", "CryptarithmDataset", "GameOfLifeConfig", diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py index 3f7dedd0..0f628e39 100644 --- a/reasoning_gym/algorithmic/caesar_cipher.py +++ b/reasoning_gym/algorithmic/caesar_cipher.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -74,8 +75,40 @@ class CaesarCipherDataset(ProceduralDataset): return { "question": f"Decrypt this Caesar cipher text: {cipher_text}. Provide only the decrypted text as your final answer.", "answer": sentence, - "metadata": {"rotation": rotation, "cipher_text": cipher_text, "clear_text": sentence}, + "metadata": { + "rotation": rotation, + "cipher_text": cipher_text, + "clear_text": sentence, + }, } -register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig) +class CaesarCipherCurriculum(BaseCurriculum): + """Curriculum for Caesar cipher task generation""" + + def __init__(self): + super().__init__(CaesarCipherCurriculum.__name__, CaesarCipherConfig) + + self._define_attributes( + RangeAttributeDefinition( + name="rotation", + levels=[5, 10, 15, 25], + default_level=0, + description="Max rotation for cipher", + attr_type=AttributeType.APPEND, + lower_field_name="min_rotation", + upper_field_name="max_rotation", + ), + RangeAttributeDefinition( + name="words", + levels=[5, 10, 15, 25], + default_level=0, + description="Max number of words", + attr_type=AttributeType.APPEND, + lower_field_name="min_words", + upper_field_name="max_words", + ), + ) + + +register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig, CaesarCipherCurriculum) diff --git a/tests/test_caesar_cipher.py b/tests/test_caesar_cipher.py index fa572d8d..54cc8065 100644 --- a/tests/test_caesar_cipher.py +++ b/tests/test_caesar_cipher.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherDataset +from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherCurriculum, CaesarCipherDataset def test_caesar_cipher_config_validation(): @@ -98,3 +98,48 @@ def test_caesar_cipher_iteration(): # Test multiple iterations yield same items assert items == list(dataset) + + +def test_caesar_cipher_curriculum(): + curriculum = CaesarCipherCurriculum() + base_value = {"size": 150, "seed": 1} + + base_cfg: CaesarCipherConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_rotation == base_cfg.max_rotation == 5 + assert base_cfg.min_words == base_cfg.max_words == 5 + + curriculum.increment_attr_level("rotation") + cfg = curriculum.generate_configuration(base_value) + assert cfg.min_rotation == 5 + assert cfg.max_rotation == 10 + + curriculum.increment_attr_level("words") + cfg = curriculum.generate_configuration(base_value) + assert cfg.min_words == 5 + assert cfg.max_words == 10 + + curriculum.increment_attr_level("rotation") + curriculum.increment_attr_level("words") + cfg = curriculum.generate_configuration(base_value) + assert cfg.min_rotation == 5 + assert cfg.max_rotation == 15 + assert cfg.min_words == 5 + assert cfg.max_words == 15 + + curriculum.increment_attr_level("rotation") + curriculum.increment_attr_level("words") + cfg = curriculum.generate_configuration(base_value) + assert cfg.min_rotation == 5 + assert cfg.max_rotation == 25 + assert cfg.min_words == 5 + assert cfg.max_words == 25 + + curriculum.decrement_attr_level("rotation") + curriculum.decrement_attr_level("words") + cfg = curriculum.generate_configuration(base_value) + assert cfg.min_rotation == 5 + assert cfg.max_rotation == 15 + assert cfg.min_words == 5 + assert cfg.max_words == 15