spell backward curriculum (#327)

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
Zafir Stojanovski 2025-03-11 00:22:28 +01:00 committed by GitHub
parent a23c8c3d4e
commit f204a848d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 59 additions and 5 deletions

View file

@ -35,7 +35,7 @@ from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDatas
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardCurriculum, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
from .string_manipulation import StringManipulationConfig, StringManipulationDataset
@ -52,6 +52,7 @@ from .word_sorting import TextTransformation, WordSortingConfig, WordSortingCurr
__all__ = [
"SpellBackwardConfig",
"SpellBackwardDataset",
"SpellBackwardCurriculum",
"BaseConversionConfig",
"BaseConversionDataset",
"BaseConversionCurriculum",

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@ -14,12 +15,14 @@ class SpellBackwardConfig:
"""Configuration for spelling words backward task generation"""
min_word_len: int = 3 # Minimum word length
max_word_len: int = 20 # Maximum word length
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_word_len > 0, "min_word_len must be positive"
assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"
class SpellBackwardDataset(ProceduralDataset):
@ -32,7 +35,9 @@ class SpellBackwardDataset(ProceduralDataset):
text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters
self.words = [
word for word in re.findall(r"\b\w+\b", text) if word.isalnum() and len(word) >= config.min_word_len
word
for word in re.findall(r"\b\w+\b", text)
if word.isalnum() and config.min_word_len <= len(word) <= config.max_word_len
]
def __getitem__(self, idx: int) -> dict:
@ -46,7 +51,11 @@ class SpellBackwardDataset(ProceduralDataset):
return {
"question": f"Spell this word backward (example: sun -> nus): {word}",
"answer": answer,
"metadata": {"word": word, "word_len": len(word)},
"metadata": {
"word": word,
"word_len": len(word),
"difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -63,4 +72,23 @@ class SpellBackwardDataset(ProceduralDataset):
return reward
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
class SpellBackwardCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SpellBackwardCurriculum.__name__, SpellBackwardConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="word_len",
levels=[5, 10, 20, 30],
default_level=1,
description="Word length",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_word_len",
upper_field_name="max_word_len",
),
)
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)

View file

@ -2,7 +2,7 @@
import pytest
from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardDataset
from reasoning_gym.algorithmic.spell_backward import SpellBackwardConfig, SpellBackwardCurriculum, SpellBackwardDataset
def test_spell_backward_config_validation():
@ -11,6 +11,10 @@ def test_spell_backward_config_validation():
config = SpellBackwardConfig(min_word_len=0)
config.validate()
with pytest.raises(AssertionError):
config = SpellBackwardConfig(min_word_len=4, max_word_len=3)
config.validate()
def test_spell_backward_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
@ -57,3 +61,24 @@ def test_spell_backward_dataset_iteration():
# Test multiple iterations yield same items
assert items == list(dataset)
def test_spell_backward_curriculum():
curriculum = SpellBackwardCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: SpellBackwardConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_word_len == 5 and base_cfg.max_word_len == 10
# test incrementing attribute levels
curriculum.increment_attr_level("word_len")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_word_len == 5 and increased_cfg.max_word_len == 20
# test decrementing attribute levels
curriculum.decrement_attr_level("word_len")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_word_len == 5 and partially_decreased_cfg.max_word_len == 10