diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index 3ebebe6e..c16dc6e5 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -35,7 +35,7 @@ from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDatas
 from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
 from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
 from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset
-from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
+from .spell_backward import SpellBackwardConfig, SpellBackwardCurriculum, SpellBackwardDataset
 from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
 from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
 from .string_manipulation import StringManipulationConfig, StringManipulationDataset
@@ -52,6 +52,7 @@ from .word_sorting import TextTransformation, WordSortingConfig, WordSortingCurr
 __all__ = [
     "SpellBackwardConfig",
     "SpellBackwardDataset",
+    "SpellBackwardCurriculum",
     "BaseConversionConfig",
     "BaseConversionDataset",
     "BaseConversionCurriculum",
diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py
index bf33441b..c2e9d767 100644
--- a/reasoning_gym/algorithmic/spell_backward.py
+++ b/reasoning_gym/algorithmic/spell_backward.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..data import read_data_file
 from ..factory import ProceduralDataset, register_dataset
 
@@ -14,12 +15,14 @@ class SpellBackwardConfig:
     """Configuration for spelling words backward task generation"""
 
     min_word_len: int = 3  # Minimum word length
+    max_word_len: int = 20  # Maximum word length
     seed: Optional[int] = None
     size: int = 500  # Virtual dataset size
 
     def validate(self) -> None:
         """Validate configuration parameters"""
         assert self.min_word_len > 0, "min_word_len must be positive"
+        assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"
 
 
 class SpellBackwardDataset(ProceduralDataset):
@@ -32,7 +35,9 @@ class SpellBackwardDataset(ProceduralDataset):
         text = read_data_file("in_the_year_2889.txt")
         # Extract words and clean them to contain only alphanumeric characters
         self.words = [
-            word for word in re.findall(r"\b\w+\b", text) if word.isalnum() and len(word) >= config.min_word_len
+            word
+            for word in re.findall(r"\b\w+\b", text)
+            if word.isalnum() and config.min_word_len <= len(word) <= config.max_word_len
         ]
 
     def __getitem__(self, idx: int) -> dict:
@@ -46,7 +51,11 @@ class SpellBackwardDataset(ProceduralDataset):
         return {
             "question": f"Spell this word backward (example: sun -> nus): {word}",
             "answer": answer,
-            "metadata": {"word": word, "word_len": len(word)},
+            "metadata": {
+                "word": word,
+                "word_len": len(word),
+                "difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)},
+            },
         }
 
     def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@@ -63,4 +72,23 @@ class SpellBackwardDataset(ProceduralDataset):
         return reward
 
 
-register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
+class SpellBackwardCurriculum(BaseCurriculum):
+    def __init__(self):
+        super().__init__(SpellBackwardCurriculum.__name__, SpellBackwardConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="word_len",
+                levels=[5, 10, 20, 30],
+                default_level=1,
+                description="Word length",
+                attr_type=AttributeType.APPEND,
+                min_value=3,
+                lower_field_name="min_word_len",
+                upper_field_name="max_word_len",
+            ),
+        )
+
+
+register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)
diff --git a/tests/test_spell_backward.py b/tests/test_spell_backward.py
index 2db86c62..64d091a7 100644
--- a/tests/test_spell_backward.py
+++ b/tests/test_spell_backward.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset
+from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardCurriculum, SpellBackwardDataset
 
 
 def test_spell_backward_config_validation():
@@ -11,6 +11,10 @@ def test_spell_backward_config_validation():
         config = SpellBackwardConfig(min_word_len=0)
         config.validate()
 
+    with pytest.raises(AssertionError):
+        config = SpellBackwardConfig(min_word_len=4, max_word_len=3)
+        config.validate()
+
 
 def test_spell_backward_dataset_deterministic():
     """Test that dataset generates same items with same seed"""
@@ -57,3 +61,24 @@ def test_spell_backward_dataset_iteration():
 
     # Test multiple iterations yield same items
     assert items == list(dataset)
+
+
+def test_spell_backward_curriculum():
+    curriculum = SpellBackwardCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: SpellBackwardConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_word_len == 5 and base_cfg.max_word_len == 10
+
+    # test incrementing attribute levels
+    curriculum.increment_attr_level("word_len")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_word_len == 5 and increased_cfg.max_word_len == 20
+
+    # test decrementing attribute levels
+    curriculum.decrement_attr_level("word_len")
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_word_len == 5 and partially_decreased_cfg.max_word_len == 10