feat(env): Number Sorting Curriculum (#321)

* number sorting curriculum

* metadata
This commit is contained in:
Zafir Stojanovski 2025-03-11 00:18:20 +01:00 committed by GitHub
parent 105374183f
commit ad48c551f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 86 additions and 9 deletions

View file

@ -2,7 +2,7 @@
import pytest
from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingCurriculum, NumberSortingDataset
def test_number_sorting_config_validation():
@ -89,3 +89,31 @@ def test_number_sorting_dataset_iteration():
# Test multiple iterations yield same items
assert items == list(dataset)
def test_number_sorting_curriculum():
curriculum = NumberSortingCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: NumberSortingConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100
assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2
assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000
# test incrementing some attribute levels
curriculum.increment_attr_level("numbers")
curriculum.increment_attr_level("decimals")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500
assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4
assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000
# test decrementing attribute level for numbers again
curriculum.decrement_attr_level("numbers")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100
assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4
assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000