diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index c16dc6e5..190bbf32 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -22,7 +22,7 @@ from .jugs import JugsConfig, JugsDataset from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset -from .number_filtering import NumberFilteringConfig, NumberFilteringDataset +from .number_filtering import NumberFilteringConfig, NumberFilteringCurriculum, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset from .palindrome_partitioning import ( @@ -73,6 +73,7 @@ __all__ = [ "LetterJumbleCurriculum", "NumberFilteringConfig", "NumberFilteringDataset", + "NumberFilteringCurriculum", "NumberSortingConfig", "NumberSortingDataset", "NumberSortingCurriculum", diff --git a/reasoning_gym/algorithmic/number_filtering.py b/reasoning_gym/algorithmic/number_filtering.py index f122f04c..523933bf 100644 --- a/reasoning_gym/algorithmic/number_filtering.py +++ b/reasoning_gym/algorithmic/number_filtering.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -94,8 +95,52 @@ class NumberFilteringDataset(ProceduralDataset): "filter_value": filter_str, "operation": f"{keep_remove}_{larger_smaller}", "result": result_strs, + "difficulty": { + "numbers": len(numbers), + "decimals": (self.config.min_decimals, self.config.max_decimals), + "value": (self.config.min_value, self.config.max_value), + }, }, } -register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig) +class NumberFilteringCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(NumberFilteringCurriculum.__name__, NumberFilteringConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="numbers", + levels=[10, 100, 500, 1000], + default_level=1, + description="How many numbers to sort", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_numbers", + upper_field_name="max_numbers", + ), + RangeAttributeDefinition( + name="decimals", + levels=[0, 2, 4, 6], + default_level=1, + description="Number of decimal places", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_decimals", + upper_field_name="max_decimals", + ), + RangeAttributeDefinition( + name="value", + levels=[-10_000, 10_000], + default_level=1, + description="Range of numbers to sort", + attr_type=AttributeType.APPEND, + min_value=-10_000, + lower_field_name="min_value", + upper_field_name="max_value", + ), + ) + + +register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum) diff --git a/tests/test_number_filtering.py b/tests/test_number_filtering.py index e70ec6b7..6033e867 100644 --- a/tests/test_number_filtering.py +++ b/tests/test_number_filtering.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset +from reasoning_gym.algorithmic.number_filtering import ( + NumberFilteringConfig, + NumberFilteringCurriculum, + NumberFilteringDataset, +) def test_number_filtering_config_validation(): @@ -116,3 +120,31 @@ def test_number_filtering_precision(): # Check that string representations maintain precision for num in item["metadata"]["original_numbers"]: assert len(num.split(".")[-1]) == 2 + + +def test_number_filtering_curriculum(): + curriculum = NumberFilteringCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: NumberFilteringConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100 + assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2 + assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000 + + # test incrementing some attribute levels + curriculum.increment_attr_level("numbers") + curriculum.increment_attr_level("decimals") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500 + assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4 + assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000 + + # test decrementing attribute level for numbers again + curriculum.decrement_attr_level("numbers") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100 + assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4 + assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000