Refactor Curriculum Attributes (#335)

* remove min_value from AttributeDefinition
* remove type from AttributeDefinition
* Add CurriculumContext
* add ensure_interval option for RangeAttributes
* docs: Add legend explaining curriculum indicators in dataset gallery
* update GALLERY.md
This commit is contained in:
Andreas Köpf 2025-03-16 15:40:28 +01:00 committed by GitHub
parent 4e7d9296ee
commit d2c895f1d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 286 additions and 677 deletions

View file

@ -1,5 +1,6 @@
import pytest
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
from reasoning_gym.games import FutoshikiConfig, FutoshikiDataset
@ -196,7 +197,8 @@ def test_futoshiki_curriculum():
base_value = {"size": 150, "seed": 1}
base_cfg: FutoshikiConfig = curriculum.generate_configuration(base_value)
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
base_cfg: FutoshikiConfig = curriculum.generate_configuration(base_value, context=context)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_board_size == 4 and base_cfg.max_board_size == 4
@ -205,46 +207,46 @@ def test_futoshiki_curriculum():
# Test incrementing attribute levels
curriculum.increment_attr_level("board_size")
curriculum.increment_attr_level("difficulty")
increased_cfg = curriculum.generate_configuration(base_value)
increased_cfg = curriculum.generate_configuration(base_value, context=context)
assert increased_cfg.min_board_size == 6 and increased_cfg.max_board_size == 6
assert increased_cfg.min_difficulty == 1 and increased_cfg.max_difficulty == 1
# Test incrementing again
curriculum.increment_attr_level("board_size")
curriculum.increment_attr_level("difficulty")
increased_cfg2 = curriculum.generate_configuration(base_value)
increased_cfg2 = curriculum.generate_configuration(base_value, context=context)
assert increased_cfg2.min_board_size == 7 and increased_cfg2.max_board_size == 7
assert increased_cfg2.min_difficulty == 2 and increased_cfg2.max_difficulty == 2
# Test incrementing to max levels
curriculum.increment_attr_level("board_size")
curriculum.increment_attr_level("difficulty")
max_cfg = curriculum.generate_configuration(base_value)
max_cfg = curriculum.generate_configuration(base_value, context=context)
assert max_cfg.min_board_size == 9 and max_cfg.max_board_size == 9
assert max_cfg.min_difficulty == 3 and max_cfg.max_difficulty == 3
# Test that we can't go beyond max levels
assert not curriculum.increment_attr_level("board_size")
assert not curriculum.increment_attr_level("difficulty")
still_max_cfg = curriculum.generate_configuration(base_value)
still_max_cfg = curriculum.generate_configuration(base_value, context=context)
assert still_max_cfg.min_board_size == 9 and still_max_cfg.max_board_size == 9
assert still_max_cfg.min_difficulty == 3 and still_max_cfg.max_difficulty == 3
# Test decrementing attribute levels
curriculum.decrement_attr_level("board_size")
curriculum.decrement_attr_level("difficulty")
decreased_cfg = curriculum.generate_configuration(base_value)
decreased_cfg = curriculum.generate_configuration(base_value, context=context)
assert decreased_cfg.min_board_size == 7 and decreased_cfg.max_board_size == 7
assert decreased_cfg.min_difficulty == 2 and decreased_cfg.max_difficulty == 2
# Test global level setting
curriculum.set_global_level(0)
global_lvl0_cfg = curriculum.generate_configuration(base_value)
global_lvl0_cfg = curriculum.generate_configuration(base_value, context=context)
assert global_lvl0_cfg.min_board_size == 4 and global_lvl0_cfg.max_board_size == 4
assert global_lvl0_cfg.min_difficulty == 0 and global_lvl0_cfg.max_difficulty == 0
# Test global level increment
curriculum.increment_global_level()
global_lvl1_cfg = curriculum.generate_configuration(base_value)
global_lvl1_cfg = curriculum.generate_configuration(base_value, context=context)
assert global_lvl1_cfg.min_board_size == 6 and global_lvl1_cfg.max_board_size == 6
assert global_lvl1_cfg.min_difficulty == 1 and global_lvl1_cfg.max_difficulty == 1