diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 62450f12..6185d73e 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -23,7 +23,7 @@ from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, Let from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset -from .number_sorting import NumberSortingConfig, NumberSortingDataset +from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_partitioning import ( PalindromePartitioningConfig, @@ -74,6 +74,7 @@ __all__ = [ "NumberFilteringDataset", "NumberSortingConfig", "NumberSortingDataset", + "NumberSortingCurriculum", "SentenceReorderingConfig", "SentenceReorderingDataset", "WordSequenceReversalConfig", diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index c7170347..3c02896c 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -46,16 +47,14 @@ Please follow the instruction below: # Reparse to ensure exact decimal representation return f"{float(formatted):.{decimals}f}" - def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]: + def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]: """Generate list of numbers and their string representations""" - count = rng.randint(self.config.min_numbers, self.config.max_numbers) - decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) - numbers = [] number_strs = [] for _ in range(count): num = rng.uniform(self.config.min_value, self.config.max_value) + decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) num_str = self._format_number(num, decimals) # Reparse to ensure exact value num = float(num_str) @@ -68,7 +67,8 @@ Please follow the instruction below: """Generate a single sorting task""" rng = Random(self.seed + idx) - numbers, number_strs = self._generate_numbers(rng) + count = rng.randint(self.config.min_numbers, self.config.max_numbers) + numbers, number_strs = self._generate_numbers(rng, count) # Generate both ascending and descending answers asc_numbers = sorted(numbers) @@ -88,8 +88,56 @@ Please follow the instruction below: return { "question": question, "answer": str(answer), - "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer}, + "metadata": { + "original_numbers": number_strs, + "direction": direction, + "sorted_numbers": answer, + "difficulty": { + "numbers": count, + "decimals": (self.config.min_decimals, self.config.max_decimals), + "value": (self.config.min_value, self.config.max_value), + }, + }, } -register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig) +class NumberSortingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(NumberSortingCurriculum.__name__, NumberSortingConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="numbers", + levels=[10, 100, 500, 1000], + default_level=1, + description="How many numbers to sort", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_numbers", + upper_field_name="max_numbers", + ), + RangeAttributeDefinition( + name="decimals", + levels=[0, 2, 4, 6], + default_level=1, + description="Number of decimal places", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_decimals", + upper_field_name="max_decimals", + ), + RangeAttributeDefinition( + name="value", + levels=[-10_000, 10_000], + default_level=1, + description="Range of numbers to sort", + attr_type=AttributeType.APPEND, + min_value=-10_000, + lower_field_name="min_value", + upper_field_name="max_value", + ), + ) + + +register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig, NumberSortingCurriculum) diff --git a/tests/test_number_sorting.py b/tests/test_number_sorting.py index 88916076..531a3103 100644 --- a/tests/test_number_sorting.py +++ b/tests/test_number_sorting.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset +from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset def test_number_sorting_config_validation(): @@ -89,3 +89,31 @@ def test_number_sorting_dataset_iteration(): # Test multiple iterations yield same items assert items == list(dataset) + + +def test_number_sorting_curriculum(): + curriculum = NumberSortingCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: NumberSortingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100 + assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2 + assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000 + + # test incrementing some attribute levels + curriculum.increment_attr_level("numbers") + curriculum.increment_attr_level("decimals") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500 + assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4 + assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000 + + # test decrementing attribute level for numbers again + curriculum.decrement_attr_level("numbers") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100 + assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4 + assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000