string manipulation curriculum (#306)

This commit is contained in:
Zafir Stojanovski 2025-03-09 18:12:35 +01:00 committed by GitHub
parent e1e05884ee
commit 7c7c783883
3 changed files with 69 additions and 3 deletions

View file

@ -38,7 +38,7 @@ from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDat
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_manipulation import StringManipulationConfig, StringManipulationCurriculum, StringManipulationDataset
from .string_splitting import StringSplittingConfig, StringSplittingDataset from .string_splitting import StringSplittingConfig, StringSplittingDataset
from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderDataset

View file

@ -7,6 +7,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated. QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated.
@ -42,6 +43,7 @@ class StringManipulationConfig:
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3" assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3"
assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum" assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum"
assert self.max_num_rules <= 20, "Maximum number of rules should be at most 20"
class StringManipulationDataset(ProceduralDataset): class StringManipulationDataset(ProceduralDataset):
@ -181,8 +183,43 @@ class StringManipulationDataset(ProceduralDataset):
"solution": answer, "solution": answer,
"states": states, "states": states,
"selected_rules": [rule for rule, _ in selected_rules], "selected_rules": [rule for rule, _ in selected_rules],
"difficulty": {
"string_length": string_length,
"num_rules": num_rules,
},
}, },
} }
register_dataset("string_manipulation", StringManipulationDataset, StringManipulationConfig) class StringManipulationCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(StringManipulationCurriculum.__name__, StringManipulationConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="string_length",
levels=[10, 50, 100, 500],
default_level=0,
description="Length of the string",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_string_length",
upper_field_name="max_string_length",
),
RangeAttributeDefinition(
name="num_rules",
levels=[5, 10, 15, 20],
default_level=0,
description="Number of rules to apply",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_num_rules",
upper_field_name="max_num_rules",
),
)
register_dataset(
"string_manipulation", StringManipulationDataset, StringManipulationConfig, StringManipulationCurriculum
)

View file

@ -2,7 +2,11 @@
import pytest import pytest
from reasoning_gym.algorithmic.string_manipulation import StringManipulationConfig, StringManipulationDataset from reasoning_gym.algorithmic.string_manipulation import (
StringManipulationConfig,
StringManipulationCurriculum,
StringManipulationDataset,
)
def test_string_manipulation_config_validation(): def test_string_manipulation_config_validation():
@ -255,3 +259,28 @@ def test_string_manipulation_answer():
) )
] ]
assert dataset._get_all_transforms("acab", rules)[-1] == "zzab" assert dataset._get_all_transforms("acab", rules)[-1] == "zzab"
def test_string_manipulation_curriculum():
curriculum = StringManipulationCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: StringManipulationConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_string_length == 10 and base_cfg.max_string_length == 10
assert base_cfg.min_num_rules == 5 and base_cfg.max_num_rules == 5
# test incrementing attribute levels
curriculum.increment_attr_level("string_length")
curriculum.increment_attr_level("num_rules")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_string_length == 10 and increased_cfg.max_string_length == 50
assert increased_cfg.min_num_rules == 5 and increased_cfg.max_num_rules == 10
# test decrementing attribute level for string_length again
curriculum.decrement_attr_level("string_length")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_string_length == 10 and partially_decreased_cfg.max_string_length == 10
assert partially_decreased_cfg.min_num_rules == 5 and partially_decreased_cfg.max_num_rules == 10