diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 8f167f09..7eb204d0 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -39,7 +39,7 @@ from .string_splitting import StringSplittingConfig, StringSplittingDataset from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset -from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset +from .word_sorting import TextTransformation, WordSortingConfig, WordSortingCurriculum, WordSortingDataset __all__ = [ "SpellBackwardConfig", @@ -67,6 +67,7 @@ __all__ = [ "SentenceReorderingDataset", "WordSequenceReversalConfig", "WordSequenceReversalDataset", + "WordSortingCurriculum", "WordSortingConfig", "WordSortingDataset", "TextTransformation", diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index d246bd5a..fe29aaa7 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -6,6 +6,7 @@ from enum import StrEnum from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -105,11 +106,14 @@ class WordSortingDataset(ProceduralDataset): "question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)), "answer": ", ".join(answer), "metadata": { + "difficulty": { + "num_words": len(original_words), + "word_length": max(len(word) for word in original_words), + }, "original_words": original_words, + "sorted_words": answer, "transformed_words": transformed_words, "direction": direction, - "transformation": self.config.transformation, - "sorted_words": answer, }, } @@ -125,4 +129,32 @@ class WordSortingDataset(ProceduralDataset): return 0.0 +class WordSortingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(WordSortingCurriculum.__name__, WordSortingConfig) + + self._define_attributes( + RangeAttributeDefinition( + name="num_words", + levels=[5, 10, 20, 30], + default_level=0, + description="Number of words to sort", + attr_type=AttributeType.APPEND, + min_value=5, + lower_field_name="min_words", + upper_field_name="max_words", + ), + RangeAttributeDefinition( + name="word_length", + levels=[3, 6, 9, 12], + default_level=0, + description="Length of words to sort", + attr_type=AttributeType.APPEND, + min_value=3, + lower_field_name="min_word_length", + upper_field_name="max_word_length", + ), + ) + + register_dataset("word_sorting", WordSortingDataset, WordSortingConfig) diff --git a/tests/test_word_sorting.py b/tests/test_word_sorting.py index 802c1472..a3dcc454 100644 --- a/tests/test_word_sorting.py +++ b/tests/test_word_sorting.py @@ -2,7 +2,12 @@ import pytest -from reasoning_gym.algorithmic.word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset +from reasoning_gym.algorithmic.word_sorting import ( + TextTransformation, + WordSortingConfig, + WordSortingCurriculum, + WordSortingDataset, +) def test_word_sorting_config_validation(): @@ -78,8 +83,6 @@ def test_word_sorting_dataset_items(): # Check metadata assert "original_words" in item["metadata"] assert "transformed_words" in item["metadata"] - assert "direction" in item["metadata"] - assert "transformation" in item["metadata"] assert "sorted_words" in item["metadata"] # Verify word count constraints @@ -148,3 +151,46 @@ def test_word_sorting_scoring(): # Empty answer answer = None assert dataset.score_answer(answer, item) == 0.0 + + +def test_word_sorting_curriculum(): + """Test the WordSortingCurriculum functionality""" + + curriculum = WordSortingCurriculum() + + base_value = {"size": 150, "seed": 1} + + # Test base configuration + base_cfg: WordSortingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_words == 5 and base_cfg.max_words == 5 + assert base_cfg.min_word_length == 3 and base_cfg.max_word_length == 3 + assert base_cfg.transformation == TextTransformation.ORIGINAL + + # Test incrementing num_words attribute level + curriculum.increment_attr_level("num_words") + words_cfg = curriculum.generate_configuration(base_value) + assert words_cfg.min_words == 5 and words_cfg.max_words == 10 + + # Test incrementing word_length attribute level + curriculum.set_attr_level("num_words", 0) # Reset num_words to default level + curriculum.increment_attr_level("word_length") + length_cfg = curriculum.generate_configuration(base_value) + assert length_cfg.min_word_length == 3 and length_cfg.max_word_length == 6 + + # Test incrementing both attributes + curriculum.set_attr_level("num_words", 0) # Reset to default levels + curriculum.set_attr_level("word_length", 0) + curriculum.increment_attr_level("num_words") + curriculum.increment_attr_level("word_length") + combined_cfg = curriculum.generate_configuration(base_value) + assert combined_cfg.min_words == 5 and combined_cfg.max_words == 10 + assert combined_cfg.min_word_length == 3 and combined_cfg.max_word_length == 6 + + # Test max level + curriculum.set_attr_level("num_words", 0) # Reset to default level + for _ in range(5): # More than the number of levels + curriculum.increment_attr_level("num_words") + max_level_cfg = curriculum.generate_configuration(base_value) + assert max_level_cfg.min_words == 5 and max_level_cfg.max_words == 30 # Should be at the highest level