mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
Refactor Curriculum Attributes (#335)
* remove min_value from AttributeDefinition * remove type from AttributeDefinition * Add CurriculumContext * add ensure_interval option for RangeAttributes * docs: Add legend explaining curriculum indicators in dataset gallery * update GALLERY.md
This commit is contained in:
parent
4e7d9296ee
commit
d2c895f1d3
101 changed files with 286 additions and 677 deletions
|
|
@ -3,6 +3,7 @@ import io
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
|
||||
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
|
||||
from reasoning_gym.coaching.experiment import CurriculumExperiment
|
||||
|
||||
|
|
@ -16,7 +17,13 @@ def test_curriculum_experiment_initialization():
|
|||
)
|
||||
|
||||
# Create experiment
|
||||
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
|
||||
experiment = CurriculumExperiment(
|
||||
name="test_experiment",
|
||||
config=config,
|
||||
context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE),
|
||||
size=10,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Check experiment was created correctly
|
||||
assert experiment.name == "test_experiment"
|
||||
|
|
@ -66,7 +73,8 @@ def test_curriculum_experiment_mixed_levels():
|
|||
}
|
||||
)
|
||||
|
||||
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
|
||||
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
|
||||
experiment = CurriculumExperiment(name="test_experiment", config=config, context=context, size=10, seed=42)
|
||||
|
||||
curriculum = experiment.curricula["leg_counting"]
|
||||
assert curriculum.get_attr_level("num_animals") == 4 # Specific override
|
||||
|
|
@ -116,7 +124,8 @@ def test_curriculum_experiment_from_yaml():
|
|||
assert chain_sum.weight == 0.8
|
||||
|
||||
# Create experiment from the loaded config
|
||||
experiment = CurriculumExperiment(name="yaml_test", config=config, size=10, seed=42)
|
||||
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
|
||||
experiment = CurriculumExperiment(name="yaml_test", config=config, context=context, size=10, seed=42)
|
||||
|
||||
# Verify experiment was created correctly
|
||||
assert "leg_counting" in experiment.curricula
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue