diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 9fd0e64c..ea3774fa 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -38,8 +38,8 @@ from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDat from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset -from .string_manipulation import StringManipulationConfig, StringManipulationCurriculum, StringManipulationDataset -from .string_splitting import StringSplittingConfig, StringSplittingDataset +from .string_manipulation import StringManipulationConfig, StringManipulationDataset +from .string_splitting import StringSplittingConfig, StringSplittingCurriculum, StringSplittingDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import ( @@ -129,6 +129,7 @@ __all__ = [ "StringManipulationCurriculum", "StringSplittingConfig", "StringSplittingDataset", + "StringSplittingCurriculum", "StringSynthesisConfig", "StringSynthesisDataset", "RottenOrangesConfig", diff --git a/reasoning_gym/algorithmic/string_splitting.py b/reasoning_gym/algorithmic/string_splitting.py index 6679e812..490cfc47 100644 --- a/reasoning_gym/algorithmic/string_splitting.py +++ b/reasoning_gym/algorithmic/string_splitting.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """There is a dismantling engineer who has old machines A, B, and C. @@ -24,6 +25,7 @@ The output should be the count of each machine and part type after the rules hav For example 1 0 1 5 4 3 means that you have 1 machine A, 0 machine B, 1 machine C, 5 part X, 4 part Y, and 3 part Z. Now, you have {A_machine} machine A, {B_machine} machine B, and {C_machine} machine C. Provide the count of each machine and part type after applying the above rules. +Note: Apply the rules at most {max_iterations} times. If the rules cannot be applied anymore, or if you have reached the maximum number of iterations, stop and provide the current counts of each machine and part type. """ @@ -115,10 +117,40 @@ class StringSplittingDataset(ProceduralDataset): answer_str = " ".join(str(x) for x in answer) return { - "question": QUESTION_TEMPLATE.format(A_machine=A_machine, B_machine=B_machine, C_machine=C_machine), + "question": QUESTION_TEMPLATE.format( + A_machine=A_machine, + B_machine=B_machine, + C_machine=C_machine, + max_iterations=self.config.max_iterations, + ), "answer": answer_str, - "metadata": {"states": states, "solution": answer}, + "metadata": { + "states": states, + "solution": answer, + "difficulty": { + "initial_machines": (A_machine, B_machine, C_machine), + }, + }, } -register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig) +class StringSplittingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(StringSplittingCurriculum.__name__, StringSplittingConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="initial_machines", + levels=[10, 50, 100, 500], + default_level=1, + description="Number of initial machines", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_initial_machines", + upper_field_name="max_initial_machines", + ) + ) + + +register_dataset("string_splitting", StringSplittingDataset, StringSplittingConfig, StringSplittingCurriculum) diff --git a/tests/test_string_splitting.py b/tests/test_string_splitting.py index ef78eaa6..6d8f9e6b 100644 --- a/tests/test_string_splitting.py +++ b/tests/test_string_splitting.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.string_splitting import StringSplittingConfig, StringSplittingDataset +from reasoning_gym.algorithmic.string_splitting import ( + StringSplittingConfig, + StringSplittingCurriculum, + StringSplittingDataset, +) def test_string_splitting_config_validation(): @@ -106,3 +110,24 @@ def test_string_splitting_answer(): [0, 0, 1, 3, 1, 1], [0, 0, 1, 2, 0, 2], ] + + +def test_string_splitting_curriculum(): + curriculum = StringSplittingCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: StringSplittingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_initial_machines == 10 and base_cfg.max_initial_machines == 50 + + # test incrementing attribute levels + curriculum.increment_attr_level("initial_machines") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_initial_machines == 10 and increased_cfg.max_initial_machines == 100 + + # test decrementing attribute level for initial_machines again + curriculum.decrement_attr_level("initial_machines") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_initial_machines == 10 and partially_decreased_cfg.max_initial_machines == 50