add word ladder curriculum (#361)

* add word ladder curriculum

* add to __init__.py
This commit is contained in:
Oliver Stanley 2025-03-14 15:10:52 +00:00 committed by GitHub
parent 8f8bd9d756
commit b5651e5e2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 55 additions and 4 deletions

View file

@ -41,7 +41,7 @@ from .string_insertion import StringInsertionConfig, StringInsertionCurriculum,
from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset
from .string_splitting import StringSplittingConfig, StringSplittingCurriculum, StringSplittingDataset from .string_splitting import StringSplittingConfig, StringSplittingCurriculum, StringSplittingDataset
from .string_synthesis import StringSynthesisConfig, StringSynthesisCurriculum, StringSynthesisDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisCurriculum, StringSynthesisDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderCurriculum, WordLadderDataset
from .word_sequence_reversal import ( from .word_sequence_reversal import (
WordSequenceReversalConfig, WordSequenceReversalConfig,
WordSequenceReversalCurriculum, WordSequenceReversalCurriculum,
@ -90,6 +90,7 @@ __all__ = [
"WordSortingDataset", "WordSortingDataset",
"TextTransformation", "TextTransformation",
"WordLadderConfig", "WordLadderConfig",
"WordLadderCurriculum",
"WordLadderDataset", "WordLadderDataset",
"PalindromeConfig", "PalindromeConfig",
"PalindromeDataset", "PalindromeDataset",

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import get_data_file_path from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -217,7 +218,16 @@ class WordLadderDataset(ProceduralDataset):
return { return {
"question": QUESTION_TEMPLATE.format(start=start, end=end), "question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": ",".join(path), "answer": ",".join(path),
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, "metadata": {
"start_word": start,
"end_word": end,
"word_length": length,
"chain_length": len(path),
"difficulty": {
"word_length": length,
"chain_length": len(path),
},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -259,4 +269,23 @@ class WordLadderDataset(ProceduralDataset):
return reward return reward
class WordLadderCurriculum(BaseCurriculum):
    """Curriculum for the word ladder task, driven by a single word-length attribute.

    The ``word_length`` attribute is a range attribute: its current level is
    written into the ``min_word_length`` / ``max_word_length`` fields of the
    generated ``WordLadderConfig``.
    """

    def __init__(self):
        """Register the curriculum under its class name and define its attributes."""
        super().__init__(WordLadderCurriculum.__name__, WordLadderConfig)

        # Define attributes
        self._define_attributes(
            RangeAttributeDefinition(
                name="word_length",
                # Candidate word lengths, ordered from easiest to hardest.
                levels=[3, 4, 5, 6],
                default_level=1,
                description="Length of words in the puzzle",
                # NOTE(review): APPEND presumably keeps the lower bound at the
                # first level and extends the upper bound as the level rises
                # (the default level is expected to give 3..4) -- confirm
                # against the AttributeType definition.
                attr_type=AttributeType.APPEND,
                min_value=2,
                lower_field_name="min_word_length",
                upper_field_name="max_word_length",
            )
        )
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig, WordLadderCurriculum)

View file

@ -3,7 +3,7 @@ from random import Random
import pytest import pytest
from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderCurriculum, WordLadderDataset
def test_word_ladder_config_validation(): def test_word_ladder_config_validation():
@ -397,3 +397,24 @@ def test_word_ladder_score_answer():
# Test with unknown words (should return partial credit) # Test with unknown words (should return partial credit)
assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) < 1.0 assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) < 1.0
assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) > 0.0 assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) > 0.0
def test_word_ladder_curriculum():
    """Check the default word-length range and how level changes move its upper bound."""
    curriculum = WordLadderCurriculum()
    overrides = {"size": 150, "seed": 1}

    # Default level: the overrides are passed through and the range is 3..4.
    default_cfg: WordLadderConfig = curriculum.generate_configuration(overrides)
    assert default_cfg.seed == 1
    assert default_cfg.size == 150
    assert (default_cfg.min_word_length, default_cfg.max_word_length) == (3, 4)

    # One level up: only the upper bound grows.
    curriculum.increment_attr_level("word_length")
    raised_cfg = curriculum.generate_configuration(overrides)
    assert (raised_cfg.min_word_length, raised_cfg.max_word_length) == (3, 5)

    # One level back down: the default range is restored.
    curriculum.decrement_attr_level("word_length")
    restored_cfg = curriculum.generate_configuration(overrides)
    assert (restored_cfg.min_word_length, restored_cfg.max_word_length) == (3, 4)