mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-03 17:53:26 +00:00
number filtering curriculum (#333)
This commit is contained in:
parent
f14662e213
commit
aa6ccf1946
3 changed files with 81 additions and 3 deletions
|
|
@ -22,7 +22,7 @@ from .jugs import JugsConfig, JugsDataset
|
||||||
from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset
|
from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset
|
||||||
from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset
|
from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset
|
||||||
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset
|
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset
|
||||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
from .number_filtering import NumberFilteringConfig, NumberFilteringCurriculum, NumberFilteringDataset
|
||||||
from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
|
from .number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
|
||||||
from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
|
from .palindrome_generation import PalindromeConfig, PalindromeCurriculum, PalindromeDataset
|
||||||
from .palindrome_partitioning import (
|
from .palindrome_partitioning import (
|
||||||
|
|
@ -73,6 +73,7 @@ __all__ = [
|
||||||
"LetterJumbleCurriculum",
|
"LetterJumbleCurriculum",
|
||||||
"NumberFilteringConfig",
|
"NumberFilteringConfig",
|
||||||
"NumberFilteringDataset",
|
"NumberFilteringDataset",
|
||||||
|
"NumberFilteringCurriculum",
|
||||||
"NumberSortingConfig",
|
"NumberSortingConfig",
|
||||||
"NumberSortingDataset",
|
"NumberSortingDataset",
|
||||||
"NumberSortingCurriculum",
|
"NumberSortingCurriculum",
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -94,8 +95,52 @@ class NumberFilteringDataset(ProceduralDataset):
|
||||||
"filter_value": filter_str,
|
"filter_value": filter_str,
|
||||||
"operation": f"{keep_remove}_{larger_smaller}",
|
"operation": f"{keep_remove}_{larger_smaller}",
|
||||||
"result": result_strs,
|
"result": result_strs,
|
||||||
|
"difficulty": {
|
||||||
|
"numbers": len(numbers),
|
||||||
|
"decimals": (self.config.min_decimals, self.config.max_decimals),
|
||||||
|
"value": (self.config.min_value, self.config.max_value),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig)
|
class NumberFilteringCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(NumberFilteringCurriculum.__name__, NumberFilteringConfig)
|
||||||
|
|
||||||
|
# Define attributes
|
||||||
|
self._define_attributes(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="numbers",
|
||||||
|
levels=[10, 100, 500, 1000],
|
||||||
|
default_level=1,
|
||||||
|
description="How many numbers to sort",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=2,
|
||||||
|
lower_field_name="min_numbers",
|
||||||
|
upper_field_name="max_numbers",
|
||||||
|
),
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="decimals",
|
||||||
|
levels=[0, 2, 4, 6],
|
||||||
|
default_level=1,
|
||||||
|
description="Number of decimal places",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=0,
|
||||||
|
lower_field_name="min_decimals",
|
||||||
|
upper_field_name="max_decimals",
|
||||||
|
),
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="value",
|
||||||
|
levels=[-10_000, 10_000],
|
||||||
|
default_level=1,
|
||||||
|
description="Range of numbers to sort",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=-10_000,
|
||||||
|
lower_field_name="min_value",
|
||||||
|
upper_field_name="max_value",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig, NumberFilteringCurriculum)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,11 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
from reasoning_gym.algorithmic.number_filtering import (
|
||||||
|
NumberFilteringConfig,
|
||||||
|
NumberFilteringCurriculum,
|
||||||
|
NumberFilteringDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_number_filtering_config_validation():
|
def test_number_filtering_config_validation():
|
||||||
|
|
@ -116,3 +120,31 @@ def test_number_filtering_precision():
|
||||||
# Check that string representations maintain precision
|
# Check that string representations maintain precision
|
||||||
for num in item["metadata"]["original_numbers"]:
|
for num in item["metadata"]["original_numbers"]:
|
||||||
assert len(num.split(".")[-1]) == 2
|
assert len(num.split(".")[-1]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_number_filtering_curriculum():
|
||||||
|
curriculum = NumberFilteringCurriculum()
|
||||||
|
|
||||||
|
base_value = {"size": 150, "seed": 1}
|
||||||
|
|
||||||
|
base_cfg: NumberFilteringConfig = curriculum.generate_configuration(base_value)
|
||||||
|
assert base_cfg.seed == 1
|
||||||
|
assert base_cfg.size == 150
|
||||||
|
assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100
|
||||||
|
assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2
|
||||||
|
assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000
|
||||||
|
|
||||||
|
# test incrementing some attribute levels
|
||||||
|
curriculum.increment_attr_level("numbers")
|
||||||
|
curriculum.increment_attr_level("decimals")
|
||||||
|
increased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500
|
||||||
|
assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4
|
||||||
|
assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000
|
||||||
|
|
||||||
|
# test decrementing attribute level for numbers again
|
||||||
|
curriculum.decrement_attr_level("numbers")
|
||||||
|
partially_decreased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100
|
||||||
|
assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4
|
||||||
|
assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue