diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index b70a1e8d..9fd0e64c 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -19,7 +19,7 @@ from .graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDatas from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset from .jugs import JugsConfig, JugsDataset -from .letter_counting import LetterCountingConfig, LetterCountingDataset +from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset @@ -66,6 +66,7 @@ __all__ = [ "GameOfLifeHaltingDataset", "LetterCountingConfig", "LetterCountingDataset", + "LetterCountingCurriculum", "LetterJumbleConfig", "LetterJumbleDataset", "NumberFilteringConfig", diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 2ed65737..43d49178 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -7,6 +7,7 @@ from typing import Optional from reasoning_gym.data import read_data_file +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -41,7 +42,10 @@ class LetterCountingDataset(ProceduralDataset): rng = Random(self.seed + idx) # Select random span of words - span_length = rng.randint(self.config.min_words, self.config.max_words) + span_length = min( + rng.randint(self.config.min_words, self.config.max_words), + len(self.words), + ) start_idx = rng.randint(0, 
len(self.words) - span_length) span = self.words[start_idx : start_idx + span_length] @@ -59,8 +63,32 @@ class LetterCountingDataset(ProceduralDataset): return { "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?', "answer": str(count), - "metadata": {"span_length": span_length, "target_letter": target_letter, "span": span}, + "metadata": { + "span_length": span_length, + "target_letter": target_letter, + "span": span, + "difficulty": {"words": span_length}, + }, } -register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig) +class LetterCountingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(LetterCountingCurriculum.__name__, LetterCountingConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="words", + levels=[10, 50, 100, 1000], + default_level=1, + description="Number of words in the span", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_words", + upper_field_name="max_words", + ), + ) + + +register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig, LetterCountingCurriculum) diff --git a/tests/test_letter_counting.py b/tests/test_letter_counting.py index 7c6e9bd1..6484d553 100644 --- a/tests/test_letter_counting.py +++ b/tests/test_letter_counting.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset +from reasoning_gym.algorithmic.letter_counting import ( + LetterCountingConfig, + LetterCountingCurriculum, + LetterCountingDataset, +) def test_letter_counting_config_validation(): @@ -76,3 +80,24 @@ def test_letter_counting_text_preprocessing(): assert len(dataset.words) > 0 # Verify words contain only word characters assert all(word.isalnum() for word in dataset.words) + + +def test_letter_counting_curriculum(): + curriculum = LetterCountingCurriculum() + + base_value = {"size": 150, "seed": 1} + + 
base_cfg: LetterCountingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_words == 10 and base_cfg.max_words == 50 + + # test incrementing attribute levels + curriculum.increment_attr_level("words") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_words == 10 and increased_cfg.max_words == 100 + + # test decrementing the words attribute level back to its default + curriculum.decrement_attr_level("words") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_words == 10 and partially_decreased_cfg.max_words == 50