mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* remove min_value from AttributeDefinition * remove type from AttributeDefinition * Add CurriculumContext * add ensure_interval option for RangeAttributes * docs: Add legend explaining curriculum indicators in dataset gallery * update GALLERY.md
140 lines
4.8 KiB
Python
140 lines
4.8 KiB
Python
import io
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
|
|
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
|
|
from reasoning_gym.coaching.experiment import CurriculumExperiment
|
|
|
|
|
|
def test_curriculum_experiment_initialization():
|
|
"""Test basic initialization of CurriculumExperiment"""
|
|
|
|
# Create config with leg_counting dataset
|
|
config = CurriculumExperimentConfig(
|
|
curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)}
|
|
)
|
|
|
|
# Create experiment
|
|
experiment = CurriculumExperiment(
|
|
name="test_experiment",
|
|
config=config,
|
|
context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE),
|
|
size=10,
|
|
seed=42,
|
|
)
|
|
|
|
# Check experiment was created correctly
|
|
assert experiment.name == "test_experiment"
|
|
assert "leg_counting" in experiment.curricula
|
|
assert "leg_counting" in experiment.composite.datasets
|
|
|
|
# Check curriculum was configured correctly
|
|
curriculum = experiment.curricula["leg_counting"]
|
|
assert curriculum.get_attr_level("num_animals") == 2
|
|
|
|
# Check dataset was created with correct config
|
|
dataset = experiment.composite.datasets["leg_counting"]
|
|
assert dataset.config.min_animals == 1
|
|
assert dataset.config.max_animals == 3
|
|
|
|
# Check we can get entries from the dataset
|
|
entry = experiment.get_dataset_entry(0)
|
|
assert "question" in entry
|
|
assert "answer" in entry
|
|
assert "metadata" in entry
|
|
assert entry["metadata"]["source_dataset"] == "leg_counting"
|
|
|
|
|
|
def test_curriculum_experiment_wildcard_level():
|
|
"""Test using "*" to set all attribute levels"""
|
|
|
|
config = CurriculumExperimentConfig(
|
|
curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"*": 3}, weight=1.0)}
|
|
)
|
|
|
|
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
|
|
|
|
# Check all attributes were set to level 3
|
|
curriculum = experiment.curricula["leg_counting"]
|
|
for attr_name in curriculum.attributes:
|
|
assert curriculum.get_attr_level(attr_name) == 3
|
|
|
|
|
|
def test_curriculum_experiment_mixed_levels():
|
|
"""Test mixing "*" with specific attribute levels"""
|
|
|
|
config = CurriculumExperimentConfig(
|
|
curricula={
|
|
"leg_counting": CurriculumAttributeConfig(
|
|
attribute_levels={"*": 2, "num_animals": 4}, weight=1.0 # Should override the "*" level
|
|
)
|
|
}
|
|
)
|
|
|
|
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
|
|
experiment = CurriculumExperiment(name="test_experiment", config=config, context=context, size=10, seed=42)
|
|
|
|
curriculum = experiment.curricula["leg_counting"]
|
|
assert curriculum.get_attr_level("num_animals") == 4 # Specific override
|
|
|
|
|
|
def test_curriculum_experiment_from_yaml():
|
|
"""Test loading curriculum experiment config from YAML using a string stream"""
|
|
|
|
# Create a YAML string
|
|
yaml_content = """
|
|
curricula:
|
|
leg_counting:
|
|
attribute_levels:
|
|
"*": 2
|
|
num_animals: 4
|
|
weight: 1.5
|
|
chain_sum:
|
|
attribute_levels:
|
|
num_terms: 1
|
|
num_digits: 2
|
|
weight: 0.8
|
|
"""
|
|
|
|
# Use StringIO to create a file-like object from the string
|
|
from io import StringIO
|
|
|
|
yaml_stream = StringIO(yaml_content)
|
|
|
|
# Load config from YAML stream
|
|
config = CurriculumExperimentConfig.from_yaml_stream(yaml_stream)
|
|
|
|
# Verify config was loaded correctly
|
|
assert len(config.curricula) == 2
|
|
assert "leg_counting" in config.curricula
|
|
assert "chain_sum" in config.curricula
|
|
|
|
# Check leg_counting curriculum
|
|
leg_counting = config.curricula["leg_counting"]
|
|
assert leg_counting.attribute_levels["*"] == 2
|
|
assert leg_counting.attribute_levels["num_animals"] == 4
|
|
assert leg_counting.weight == 1.5
|
|
|
|
# Check chain_sum curriculum
|
|
chain_sum = config.curricula["chain_sum"]
|
|
assert chain_sum.attribute_levels["num_terms"] == 1
|
|
assert chain_sum.attribute_levels["num_digits"] == 2
|
|
assert chain_sum.weight == 0.8
|
|
|
|
# Create experiment from the loaded config
|
|
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
|
|
experiment = CurriculumExperiment(name="yaml_test", config=config, context=context, size=10, seed=42)
|
|
|
|
# Verify experiment was created correctly
|
|
assert "leg_counting" in experiment.curricula
|
|
assert "chain_sum" in experiment.curricula
|
|
|
|
# Check attribute levels were applied
|
|
leg_curriculum = experiment.curricula["leg_counting"]
|
|
assert leg_curriculum.get_attr_level("num_animals") == 4
|
|
|
|
chain_sum_curriculum = experiment.curricula["chain_sum"]
|
|
assert chain_sum_curriculum.get_attr_level("num_terms") == 1
|
|
assert chain_sum_curriculum.get_attr_level("num_digits") == 2
|