diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index ea3774fa..32dc39be 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -40,7 +40,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMat from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_splitting import StringSplittingConfig, StringSplittingCurriculum, StringSplittingDataset -from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset +from .string_synthesis import StringSynthesisConfig, StringSynthesisCurriculum, StringSynthesisDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import ( WordSequenceReversalConfig, @@ -132,6 +132,7 @@ __all__ = [ "StringSplittingCurriculum", "StringSynthesisConfig", "StringSynthesisDataset", + "StringSynthesisCurriculum", "RottenOrangesConfig", "RottenOrangesDataset", "RottenOrangesCurriculum", diff --git a/reasoning_gym/algorithmic/string_synthesis.py b/reasoning_gym/algorithmic/string_synthesis.py index 63cafa98..cad98f7f 100644 --- a/reasoning_gym/algorithmic/string_synthesis.py +++ b/reasoning_gym/algorithmic/string_synthesis.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """There are nine different blocks [A] [B] [C] {{A}} {{B}} {{C}} (A) (B) (C) @@ -24,6 +25,7 @@ The output should be the count of each block type after the rules have been appl For example 1 0 3 0 2 0 0 0 1 means that you have 1 [A] 0 [B] 3 [C] 0 {{A}} 2 {{B}} 0 {{C}} 0 (A) 0 (B) 1 (C). Now, you have {A_square} [A], {B_square} [B], and {C_square} [C] blocks. Provide the count of each block type after applying the above rules. +Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts. """ @@ -120,10 +122,40 @@ class StringSynthesisDataset(ProceduralDataset): answer_str = " ".join(str(x) for x in answer) return { - "question": QUESTION_TEMPLATE.format(A_square=A_square, B_square=B_square, C_square=C_square), + "question": QUESTION_TEMPLATE.format( + A_square=A_square, + B_square=B_square, + C_square=C_square, + max_iterations=self.config.max_iterations, + ), "answer": answer_str, - "metadata": {"states": states, "solution": answer}, + "metadata": { + "states": states, + "solution": answer, + "difficulty": { + "initial_blocks": (A_square, B_square, C_square), + }, + }, } -register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig) +class StringSynthesisCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(StringSynthesisCurriculum.__name__, StringSynthesisConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="initial_blocks", + levels=[10, 50, 100, 500], + default_level=1, + description="Number of initial blocks", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_initial_blocks", + upper_field_name="max_initial_blocks", + ) + ) + + +register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig, StringSynthesisCurriculum) diff --git a/tests/test_string_synthesis.py b/tests/test_string_synthesis.py index 39fa4133..909ab6d5 100644 --- a/tests/test_string_synthesis.py +++ b/tests/test_string_synthesis.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.string_synthesis import StringSynthesisConfig, StringSynthesisDataset +from reasoning_gym.algorithmic.string_synthesis import ( + StringSynthesisConfig, + StringSynthesisCurriculum, + StringSynthesisDataset, +) def test_string_synthesis_config_validation(): @@ -117,3 +121,24 @@ def test_string_synthesis_answer(): [0, 1, 1, 2, 0, 0, 0, 0, 0], # Rule 1 again [0, 0, 0, 2, 1, 0, 0, 0, 0], # Rule 3 (final state) ] + + +def test_string_synthesis_curriculum(): + curriculum = StringSynthesisCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: StringSynthesisConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_initial_blocks == 10 and base_cfg.max_initial_blocks == 50 + + # test incrementing attribute levels + curriculum.increment_attr_level("initial_blocks") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_initial_blocks == 10 and increased_cfg.max_initial_blocks == 100 + + # test decrementing attribute level for initial_blocks again + curriculum.decrement_attr_level("initial_blocks") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_initial_blocks == 10 and partially_decreased_cfg.max_initial_blocks == 50