spell backward curriculum (#327)

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
Zafir Stojanovski 2025-03-11 00:22:28 +01:00 committed by GitHub
parent a23c8c3d4e
commit f204a848d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 59 additions and 5 deletions

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@ -14,12 +15,14 @@ class SpellBackwardConfig:
"""Configuration for spelling words backward task generation"""
min_word_len: int = 3 # Minimum word length
max_word_len: int = 20 # Maximum word length
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_word_len > 0, "min_word_len must be positive"
assert self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"
class SpellBackwardDataset(ProceduralDataset):
@ -32,7 +35,9 @@ class SpellBackwardDataset(ProceduralDataset):
text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters
self.words = [
word for word in re.findall(r"\b\w+\b", text) if word.isalnum() and len(word) >= config.min_word_len
word
for word in re.findall(r"\b\w+\b", text)
if word.isalnum() and config.min_word_len <= len(word) <= config.max_word_len
]
def __getitem__(self, idx: int) -> dict:
@ -46,7 +51,11 @@ class SpellBackwardDataset(ProceduralDataset):
return {
"question": f"Spell this word backward (example: sun -> nus): {word}",
"answer": answer,
"metadata": {"word": word, "word_len": len(word)},
"metadata": {
"word": word,
"word_len": len(word),
"difficulty": {"word_len": (self.config.min_word_len, self.config.max_word_len)},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -63,4 +72,23 @@ class SpellBackwardDataset(ProceduralDataset):
return reward
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
class SpellBackwardCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SpellBackwardCurriculum.__name__, SpellBackwardConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="word_len",
levels=[5, 10, 20, 30],
default_level=1,
description="Word length",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_word_len",
upper_field_name="max_word_len",
),
)
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig, SpellBackwardCurriculum)