diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 0b5c4769..72e589dd 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -25,7 +25,11 @@ from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculu from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset -from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePartitioningDataset +from .palindrome_partitioning import ( + PalindromePartitioningConfig, + PalindromePartitioningCurriculum, + PalindromePartitioningDataset, +) from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset @@ -80,6 +84,7 @@ __all__ = [ "GroupAnagramsCurriculum", "PalindromePartitioningConfig", "PalindromePartitioningDataset", + "PalindromePartitioningCurriculum", "SpiralMatrixConfig", "SpiralMatrixDataset", "SpiralMatrixCurriculum", diff --git a/reasoning_gym/algorithmic/palindrome_partitioning.py b/reasoning_gym/algorithmic/palindrome_partitioning.py index 067a19e7..52637772 100644 --- a/reasoning_gym/algorithmic/palindrome_partitioning.py +++ b/reasoning_gym/algorithmic/palindrome_partitioning.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a string, partition it such that every substring is a palindrome. @@ -30,7 +31,8 @@ class PalindromePartitioningConfig: min_string_len: int = 5 max_string_len: int = 15 - max_substring_palindome_len: int = 5 + min_substring_palindrome_len: int = 1 + max_substring_palindrome_len: int = 5 size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -39,9 +41,12 @@ class PalindromePartitioningConfig: """Validate configuration parameters""" assert 1 <= self.min_string_len, "Minimum string length must be at least 1" assert self.min_string_len <= self.max_string_len, "Minimum string length must be less than or equal to maximum" - assert 1 <= self.max_substring_palindome_len, "Maximum substring palindrome length must be at least 1" + assert 1 <= self.min_substring_palindrome_len, "Minimum substring palindrome length must be at least 1" assert ( - self.max_substring_palindome_len <= self.max_string_len + self.min_substring_palindrome_len <= self.max_substring_palindrome_len + ), "Minimum substring palindrome length must be less than or equal to maximum" + assert ( + self.max_substring_palindrome_len <= self.max_string_len ), "Maximum substring palindrome length must be less than or equal to maximum string length" @@ -108,30 +113,72 @@ class PalindromePartitioningDataset(ProceduralDataset): return letters + [middle_letter] + letters[::-1] return letters + letters[::-1] - def _get_string(self, rng: Random) -> str: + def _get_string(self, rng: Random, string_len: int) -> str: """Generate a random string""" - size = rng.randint(self.config.min_string_len, self.config.max_string_len) output = "" - - while len(output) < size: - palindrome_len = rng.randint(1, min(self.config.max_substring_palindome_len, size - len(output))) + while len(output) < string_len: + palindrome_len = min( + string_len - len(output), + rng.randint(self.config.min_substring_palindrome_len, self.config.max_substring_palindrome_len), + ) substring = "".join(self._generate_palindrome_letters(rng, palindrome_len)) output += substring - return output def __getitem__(self, idx: int) -> dict: """Generate a single Palindrome Partitioning question""" rng = Random(self.seed + idx) - string = self._get_string(rng) + + string_len = rng.randint(self.config.min_string_len, self.config.max_string_len) + string = self._get_string(rng, string_len) answer = self._palindrome_partitioning(string) answer_str = json.dumps(answer) return { "question": QUESTION_TEMPLATE.format(string=string), "answer": answer_str, - "metadata": {"string": string, "solution": answer}, + "metadata": { + "string": string, + "solution": answer, + "difficulty": { + "string_len": string_len, + }, + }, } -register_dataset("palindrome_partitioning", PalindromePartitioningDataset, PalindromePartitioningConfig) +class PalindromePartitioningCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(PalindromePartitioningCurriculum.__name__, PalindromePartitioningConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="string_len", + levels=[10, 100, 500, 1000], + default_level=0, + description="Length of the string", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_string_len", + upper_field_name="max_string_len", + ), + RangeAttributeDefinition( + name="substring_palindrome_len", + levels=[5, 10, 50, 100], + default_level=0, + description="Length of the substring palindrome", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_substring_palindrome_len", + upper_field_name="max_substring_palindrome_len", + ), + ) + + +register_dataset( + "palindrome_partitioning", + PalindromePartitioningDataset, + PalindromePartitioningConfig, + PalindromePartitioningCurriculum, +) diff --git a/tests/test_palindrome_partitioning.py b/tests/test_palindrome_partitioning.py index 7a0d386f..530a027b 100644 --- a/tests/test_palindrome_partitioning.py +++ b/tests/test_palindrome_partitioning.py @@ -4,6 +4,7 @@ import json from reasoning_gym.algorithmic.palindrome_partitioning import ( PalindromePartitioningConfig, + PalindromePartitioningCurriculum, PalindromePartitioningDataset, ) @@ -109,3 +110,31 @@ def test_palindrome_partitioning_score_answer(): answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]' item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}} assert dataset.score_answer(answer, item) == 0.0 + + +def test_palindrome_partitioning_curriculum(): + curriculum = PalindromePartitioningCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: PalindromePartitioningConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_string_len == 10 and base_cfg.max_string_len == 10 + assert base_cfg.min_substring_palindrome_len == 5 and base_cfg.max_substring_palindrome_len == 5 + + # test incrementing attribute levels + curriculum.increment_attr_level("string_len") + curriculum.increment_attr_level("substring_palindrome_len") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_string_len == 10 and increased_cfg.max_string_len == 100 + assert increased_cfg.min_substring_palindrome_len == 5 and increased_cfg.max_substring_palindrome_len == 10 + + # test decrementing attribute level for substring_palindrome_len again + curriculum.decrement_attr_level("substring_palindrome_len") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_string_len == 10 and partially_decreased_cfg.max_string_len == 100 + assert ( + partially_decreased_cfg.min_substring_palindrome_len == 5 + and partially_decreased_cfg.max_substring_palindrome_len == 5 + )