number filtering curriculum (#333)

This commit is contained in:
Zafir Stojanovski 2025-03-11 23:56:06 +01:00 committed by GitHub
parent f14662e213
commit aa6ccf1946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 81 additions and 3 deletions

View file

@ -22,7 +22,7 @@ from .jugs import JugsConfig, JugsDataset
from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset
from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_filtering import NumberFilteringConfig, NumberFilteringCurriculum, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
from .palindrome_partitioning import ( from .palindrome_partitioning import (
@ -73,6 +73,7 @@ __all__ = [
"LetterJumbleCurriculum", "LetterJumbleCurriculum",
"NumberFilteringConfig", "NumberFilteringConfig",
"NumberFilteringDataset", "NumberFilteringDataset",
"NumberFilteringCurriculum",
"NumberSortingConfig", "NumberSortingConfig",
"NumberSortingDataset", "NumberSortingDataset",
"NumberSortingCurriculum", "NumberSortingCurriculum",

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -94,8 +95,52 @@ class NumberFilteringDataset(ProceduralDataset):
"filter_value": filter_str, "filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}", "operation": f"{keep_remove}_{larger_smaller}",
"result": result_strs, "result": result_strs,
"difficulty": {
"numbers": len(numbers),
"decimals": (self.config.min_decimals, self.config.max_decimals),
"value": (self.config.min_value, self.config.max_value),
},
}, },
} }
class NumberFilteringCurriculum(BaseCurriculum):
    """Curriculum controlling the difficulty of the number filtering task.

    Three range attributes are defined, each mapped onto a lower/upper field
    pair of ``NumberFilteringConfig``:

    * ``numbers``  -> ``min_numbers`` / ``max_numbers``
    * ``decimals`` -> ``min_decimals`` / ``max_decimals``
    * ``value``    -> ``min_value`` / ``max_value``

    Every attribute uses ``AttributeType.APPEND`` and starts at level 1.
    """

    def __init__(self):
        super().__init__(NumberFilteringCurriculum.__name__, NumberFilteringConfig)

        # Define attributes
        self._define_attributes(
            RangeAttributeDefinition(
                name="numbers",
                levels=[10, 100, 500, 1000],
                default_level=1,
                # This is the filtering task; the text previously said "to sort".
                description="How many numbers to filter",
                attr_type=AttributeType.APPEND,
                min_value=2,
                lower_field_name="min_numbers",
                upper_field_name="max_numbers",
            ),
            RangeAttributeDefinition(
                name="decimals",
                levels=[0, 2, 4, 6],
                default_level=1,
                description="Number of decimal places",
                attr_type=AttributeType.APPEND,
                min_value=0,
                lower_field_name="min_decimals",
                upper_field_name="max_decimals",
            ),
            RangeAttributeDefinition(
                name="value",
                levels=[-10_000, 10_000],
                default_level=1,
                # Same copy-paste fix as for "numbers" above.
                description="Range of numbers to filter",
                attr_type=AttributeType.APPEND,
                min_value=-10_000,
                lower_field_name="min_value",
                upper_field_name="max_value",
            ),
        )
# Register the dataset class together with its config and curriculum classes,
# using "number_filtering" as the first (name) argument.
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)

View file

@ -2,7 +2,11 @@
import pytest import pytest
from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset from reasoning_gym.algorithmic.number_filtering import (
NumberFilteringConfig,
NumberFilteringCurriculum,
NumberFilteringDataset,
)
def test_number_filtering_config_validation(): def test_number_filtering_config_validation():
@ -116,3 +120,31 @@ def test_number_filtering_precision():
# Check that string representations maintain precision # Check that string representations maintain precision
for num in item["metadata"]["original_numbers"]: for num in item["metadata"]["original_numbers"]:
assert len(num.split(".")[-1]) == 2 assert len(num.split(".")[-1]) == 2
def test_number_filtering_curriculum():
    """Check the configs generated at the default and at adjusted attribute levels."""
    curriculum = NumberFilteringCurriculum()
    overrides = {"size": 150, "seed": 1}

    def ranges(cfg):
        # Collect the (lower, upper) pairs of the three curriculum-controlled fields.
        return (
            (cfg.min_numbers, cfg.max_numbers),
            (cfg.min_decimals, cfg.max_decimals),
            (cfg.min_value, cfg.max_value),
        )

    # Default levels: size/seed are passed through, ranges come from level 1.
    default_cfg = curriculum.generate_configuration(overrides)
    assert (default_cfg.seed, default_cfg.size) == (1, 150)
    assert ranges(default_cfg) == ((10, 100), (0, 2), (-10_000, 10_000))

    # Raising "numbers" and "decimals" widens only those two upper bounds.
    for attr_name in ("numbers", "decimals"):
        curriculum.increment_attr_level(attr_name)
    raised_cfg = curriculum.generate_configuration(overrides)
    assert ranges(raised_cfg) == ((10, 500), (0, 4), (-10_000, 10_000))

    # Lowering "numbers" again restores its range; "decimals" stays raised.
    curriculum.decrement_attr_level("numbers")
    lowered_cfg = curriculum.generate_configuration(overrides)
    assert ranges(lowered_cfg) == ((10, 100), (0, 4), (-10_000, 10_000))