From b5651e5e2c6073e2dbdd80f9c8fbbaa790ef397d Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Fri, 14 Mar 2025 15:10:52 +0000 Subject: [PATCH] add word ladder curriculum (#361) * add word ladder curriculum * add to __init__.py --- reasoning_gym/algorithmic/__init__.py | 3 ++- reasoning_gym/algorithmic/word_ladder.py | 33 ++++++++++++++++++++++-- tests/test_word_ladder.py | 23 ++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 337dae60..fed97fea 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -41,7 +41,7 @@ from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_splitting import StringSplittingConfig, StringSplittingCurriculum, StringSplittingDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisCurriculum, StringSynthesisDataset -from .word_ladder import WordLadderConfig, WordLadderDataset +from .word_ladder import WordLadderConfig, WordLadderCurriculum, WordLadderDataset from .word_sequence_reversal import ( WordSequenceReversalConfig, WordSequenceReversalCurriculum, @@ -90,6 +90,7 @@ __all__ = [ "WordSortingDataset", "TextTransformation", "WordLadderConfig", + "WordLadderCurriculum", "WordLadderDataset", "PalindromeConfig", "PalindromeDataset", diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index d15c7be6..7917903d 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset @@ -217,7 +218,16 @@ class WordLadderDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(start=start, end=end), "answer": ",".join(path), - "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, + "metadata": { + "start_word": start, + "end_word": end, + "word_length": length, + "chain_length": len(path), + "difficulty": { + "word_length": length, + "chain_length": len(path), + }, + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -259,4 +269,23 @@ class WordLadderDataset(ProceduralDataset): return reward -register_dataset("word_ladder", WordLadderDataset, WordLadderConfig) +class WordLadderCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(WordLadderCurriculum.__name__, WordLadderConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="word_length", + levels=[3, 4, 5, 6], + default_level=1, + description="Length of words in the puzzle", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_word_length", + upper_field_name="max_word_length", + ) + ) + + +register_dataset("word_ladder", WordLadderDataset, WordLadderConfig, WordLadderCurriculum) diff --git a/tests/test_word_ladder.py b/tests/test_word_ladder.py index 738d7939..ad16eadd 100644 --- a/tests/test_word_ladder.py +++ b/tests/test_word_ladder.py @@ -3,7 +3,7 @@ from random import Random import pytest -from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderDataset +from reasoning_gym.algorithmic.word_ladder import WordLadderConfig, WordLadderCurriculum, WordLadderDataset def test_word_ladder_config_validation(): @@ -397,3 +397,24 @@ def test_word_ladder_score_answer(): # Test with unknown words (should return partial credit) assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) < 1.0 assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) > 0.0 + + +def test_word_ladder_curriculum(): + curriculum = WordLadderCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: WordLadderConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_word_length == 3 and base_cfg.max_word_length == 4 + + # test incrementing attribute levels + curriculum.increment_attr_level("word_length") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_word_length == 3 and increased_cfg.max_word_length == 5 + + # test decrementing attribute level for word length again + curriculum.decrement_attr_level("word_length") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_word_length == 3 and partially_decreased_cfg.max_word_length == 4