diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 0db87606..5fc1df84 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -37,7 +37,7 @@ from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, Rotten from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset +from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_splitting import StringSplittingConfig, StringSplittingDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset @@ -117,6 +117,7 @@ __all__ = [ "GraphColorCurriculum", "StringInsertionConfig", "StringInsertionDataset", + "StringInsertionCurriculum", "StringManipulationConfig", "StringManipulationDataset", "StringSplittingConfig", diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index 0dafe8f4..6b5956c2 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: @@ -100,8 +101,33 @@ class StringInsertionDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(string=string), "answer": str(answer), - 
"metadata": {"string": string, "solution": answer}, + "metadata": { + "string": string, + "solution": answer, + "difficulty": { + "string_length": string_length, + }, + }, } -register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig) +class StringInsertionCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(StringInsertionCurriculum.__name__, StringInsertionConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="string_length", + levels=[10, 50, 100, 1000], + default_level=1, + description="Length of the string", + attr_type=AttributeType.APPEND, + min_value=5, + lower_field_name="min_string_length", + upper_field_name="max_string_length", + ), + ) + + +register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig, StringInsertionCurriculum) diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py index faff8d90..2bf602a0 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, StringInsertionDataset +from reasoning_gym.algorithmic.string_insertion import ( + StringInsertionConfig, + StringInsertionCurriculum, + StringInsertionDataset, +) def test_string_insertion_config_validation(): @@ -102,3 +106,24 @@ def test_string_insertion_answer(): answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']" entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"} assert dataset.score_answer(answer, entry) == 0.1 + + +def test_string_insertion_curriculum(): + curriculum = StringInsertionCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: StringInsertionConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_string_length == 10 and base_cfg.max_string_length == 50 + + # test 
incrementing attribute levels + curriculum.increment_attr_level("string_length") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_string_length == 10 and increased_cfg.max_string_length == 100 + + # test decrementing attribute level for string_length back to the default + curriculum.decrement_attr_level("string_length") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_string_length == 10 and partially_decreased_cfg.max_string_length == 50