palindrome generation curriculum (#322)

This commit is contained in:
Zafir Stojanovski 2025-03-11 00:19:11 +01:00 committed by GitHub
parent ad48c551f9
commit 9aeef4ebb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 3 deletions

View file

@ -24,7 +24,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJum
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
from .palindrome_partitioning import ( from .palindrome_partitioning import (
PalindromePartitioningConfig, PalindromePartitioningConfig,
PalindromePartitioningCurriculum, PalindromePartitioningCurriculum,
@ -88,6 +88,7 @@ __all__ = [
"WordLadderDataset", "WordLadderDataset",
"PalindromeConfig", "PalindromeConfig",
"PalindromeDataset", "PalindromeDataset",
"PalindromeCurriculum",
"GroupAnagramsConfig", "GroupAnagramsConfig",
"GroupAnagramsDataset", "GroupAnagramsDataset",
"GroupAnagramsCurriculum", "GroupAnagramsCurriculum",

View file

@ -3,6 +3,7 @@ import string
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPALTE = """Your task is, given a list of letters, to form a valid palindrome. QUESTION_TEMPALTE = """Your task is, given a list of letters, to form a valid palindrome.
@ -68,6 +69,9 @@ class PalindromeDataset(ProceduralDataset):
"metadata": { "metadata": {
"letters": scrambled_letters, "letters": scrambled_letters,
"generated_palindrome": palindrome, "generated_palindrome": palindrome,
"difficulty": {
"length": length,
},
}, },
} }
@ -116,4 +120,23 @@ class PalindromeDataset(ProceduralDataset):
return 1.0 # Correct solution return 1.0 # Correct solution
register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig) class PalindromeCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(PalindromeCurriculum.__name__, PalindromeConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="length",
levels=[10, 50, 100, 500],
default_level=1,
description="Length of the generated palindrome.",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_length",
upper_field_name="max_length",
)
)
register_dataset("palindrome_generation", PalindromeDataset, PalindromeConfig, PalindromeCurriculum)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
def test_palindrome_config_validation(): def test_palindrome_config_validation():
@ -89,3 +89,24 @@ def test_score_answer():
# Empty input should score 0.0 # Empty input should score 0.0
assert dataset.score_answer(None, entry=item) == 0.0 assert dataset.score_answer(None, entry=item) == 0.0
def test_palindrome_curriculum():
curriculum = PalindromeCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: PalindromeConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_length == 10 and base_cfg.max_length == 50
# test incrementing attribute levels
curriculum.increment_attr_level("length")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_length == 10 and increased_cfg.max_length == 100
# test decrementing attribute levels
curriculum.decrement_attr_level("length")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_length == 10 and partially_decreased_cfg.max_length == 50