Refactor Curriculum Attributes (#335)

* remove min_value from AttributeDefinition
* remove type from AttributeDefinition
* Add CurriculumContext
* add ensure_interval option for RangeAttributes
* docs: Add legend explaining curriculum indicators in dataset gallery
* update GALLERY.md
This commit is contained in:
Andreas Köpf 2025-03-16 15:40:28 +01:00 committed by GitHub
parent 4e7d9296ee
commit d2c895f1d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 286 additions and 677 deletions

View file

@ -3,6 +3,7 @@ import io
import pytest
import yaml
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment
@ -16,7 +17,13 @@ def test_curriculum_experiment_initialization():
)
# Create experiment
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
experiment = CurriculumExperiment(
name="test_experiment",
config=config,
context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE),
size=10,
seed=42,
)
# Check experiment was created correctly
assert experiment.name == "test_experiment"
@ -66,7 +73,8 @@ def test_curriculum_experiment_mixed_levels():
}
)
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
experiment = CurriculumExperiment(name="test_experiment", config=config, context=context, size=10, seed=42)
curriculum = experiment.curricula["leg_counting"]
assert curriculum.get_attr_level("num_animals") == 4 # Specific override
@ -116,7 +124,8 @@ def test_curriculum_experiment_from_yaml():
assert chain_sum.weight == 0.8
# Create experiment from the loaded config
experiment = CurriculumExperiment(name="yaml_test", config=config, size=10, seed=42)
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
experiment = CurriculumExperiment(name="yaml_test", config=config, context=context, size=10, seed=42)
# Verify experiment was created correctly
assert "leg_counting" in experiment.curricula