feat: add curriculum learning scheduler for sample-efficient RL training

Add CurriculumScheduler to atroposlib/envs/ with:
- EMA-based per-item difficulty tracking from reward signals
- Quantile-based difficulty binning (configurable N bins)
- Three sampling strategies: uniform, easy_first, competence_based
- Competence-based strategy cites Platanios et al. 2019
- Opt-in integration in BaseEnv via 3 config fields
- WandB metrics for difficulty distribution tracking
- Checkpoint save/load support

22/22 tests passing.
This commit is contained in:
RUFFY-369 2026-03-28 03:39:01 +05:30
parent c421582b6f
commit 01da524b6b
3 changed files with 638 additions and 0 deletions

View file

@ -0,0 +1,248 @@
"""
Tests for CurriculumScheduler -- difficulty-based sampling for RL training.
Tests cover:
- Uniform passthrough (default behavior unchanged)
- Easy-first annealing
- Competence-based frontier sampling
- EMA difficulty updates
- Bin assignment with quantile boundaries
- Metrics and state persistence
- Edge cases
"""
import math
import random
from typing import Dict, List
import pytest
from atroposlib.envs.curriculum import CurriculumScheduler, CurriculumStrategy
# ---------------------------------------------------------------------------
# Strategy tests
# ---------------------------------------------------------------------------
class TestUniformStrategy:
    """The default strategy: bins drawn uniformly, independent of progress."""

    def test_uniform_returns_valid_bins(self):
        """Every sampled bin index must lie in [0, n_bins)."""
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        for step in range(100):
            chosen = sched.sample_bin(step, 1000)
            assert 0 <= chosen < 5

    def test_uniform_covers_all_bins(self):
        """Uniform should eventually sample from every bin."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        # 200 draws from 5 bins: overwhelmingly likely to hit all of them.
        seen = {sched.sample_bin(0, 1000) for _ in range(200)}
        assert seen == set(range(5))
class TestEasyFirstStrategy:
    """easy_first: biased toward low (easy) bins early, annealing to uniform."""

    def test_early_training_prefers_easy(self):
        """At step 0, easy_first should strongly prefer low bins (easy)."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=0.5
        )
        samples = [sched.sample_bin(0, 1000) for _ in range(200)]
        n_easy = len([b for b in samples if b <= 1])
        n_hard = len([b for b in samples if b >= 3])
        # Early training should draw easy bins more often than hard ones.
        assert n_easy > n_hard

    def test_late_training_approaches_uniform(self):
        """Near the end of training, easy_first should be roughly uniform."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=1.0
        )
        distribution = sched._easy_first_probs(progress=1.0)
        # At progress=1.0 each of the 5 bins should sit near 1/5.
        assert all(abs(p - 0.2) < 0.05 for p in distribution)
class TestCompetenceBasedStrategy:
    """competence_based: a difficulty frontier that tracks training progress."""

    def test_competence_frontier_moves(self):
        """The frontier should shift from easy to hard as training progresses."""
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )

        def mean_bin(step):
            # Average sampled bin index over 200 draws at a fixed step.
            draws = [sched.sample_bin(step, 1000) for _ in range(200)]
            return sum(draws) / len(draws)

        random.seed(42)
        early_mean = mean_bin(0)    # frontier should sit near easy bins
        late_mean = mean_bin(900)   # frontier should sit near hard bins
        # Later in training the average bin index must be higher (harder).
        assert late_mean > early_mean

    def test_mid_training_prefers_middle(self):
        """At 50% progress, competence_based should prefer middle bins."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )
        samples = [sched.sample_bin(500, 1000) for _ in range(300)]
        middle = sum(1 for b in samples if b in (1, 2, 3))
        edges = sum(1 for b in samples if b in (0, 4))
        assert middle > edges
# ---------------------------------------------------------------------------
# EMA difficulty tracking tests
# ---------------------------------------------------------------------------
class TestDifficultyTracking:
    """Per-item difficulty estimates maintained as an EMA of reward signals."""

    def test_ema_update(self):
        sched = CurriculumScheduler(strategy="uniform", ema_alpha=0.5)
        # First observation seeds the estimate directly.
        sched.update("item_1", 1.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 1.0)
        # Second observation blends in: 0.5 * 0.0 + 0.5 * 1.0 = 0.5.
        sched.update("item_1", 0.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 0.5)

    def test_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item_1", [0.8, 0.6, 1.0])
        # A batch contributes its mean score: (0.8 + 0.6 + 1.0) / 3 = 0.8.
        estimate = sched.get_item_difficulty("item_1")
        assert estimate is not None
        assert math.isclose(estimate, 0.8)

    def test_untracked_item_returns_none(self):
        sched = CurriculumScheduler(strategy="uniform")
        assert sched.get_item_difficulty("nonexistent") is None

    def test_multiple_items_tracked(self):
        sched = CurriculumScheduler(strategy="uniform")
        for key, score in [("easy", 0.9), ("hard", 0.1), ("medium", 0.5)]:
            sched.update(key, score)
        assert sched.n_items_tracked == 3
        # Higher score = easier item = higher difficulty value here.
        assert sched.get_item_difficulty("easy") > sched.get_item_difficulty("hard")
# ---------------------------------------------------------------------------
# Bin assignment tests
# ---------------------------------------------------------------------------
class TestBinAssignment:
    """Quantile binning: items are ranked by difficulty into n_bins buckets."""

    def test_easy_item_gets_low_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        # Populate items whose scores span the full [0, 1) range.
        for idx in range(100):
            sched.update(f"item_{idx}", idx / 100.0)
        # High score = easy = low bin; low score = hard = high bin.
        assert sched.get_item_bin("item_95") < sched.get_item_bin("item_5")

    def test_untracked_gets_middle_bin(self):
        """Unknown items fall back to the middle bin (n_bins // 2)."""
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        assert sched.get_item_bin("unknown") == 2

    def test_single_bin(self):
        """With one bin, everything lands in bin 0."""
        sched = CurriculumScheduler(strategy="uniform", n_bins=1)
        sched.update("item", 0.5)
        assert sched.get_item_bin("item") == 0
# ---------------------------------------------------------------------------
# Metrics and state tests
# ---------------------------------------------------------------------------
class TestMetrics:
    """metrics_dict() output used for WandB logging."""

    def test_metrics_dict_empty(self):
        """A fresh scheduler reports zero tracked items."""
        sched = CurriculumScheduler(strategy="uniform")
        reported = sched.metrics_dict()
        assert "curriculum/items_tracked" in reported
        assert reported["curriculum/items_tracked"] == 0

    def test_metrics_dict_populated(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=3)
        # 60 updates — enough to trigger rebinning.
        for idx in range(60):
            sched.update(f"item_{idx}", idx / 60.0)
        reported = sched.metrics_dict()
        assert reported["curriculum/items_tracked"] == 60
        expected_keys = (
            "curriculum/mean_difficulty",
            "curriculum/min_difficulty",
            "curriculum/max_difficulty",
            "curriculum/total_updates",
        )
        for key in expected_keys:
            assert key in reported
class TestStatePersistence:
    """state_dict() / load_state_dict() round-trip for checkpointing."""

    def test_save_load_roundtrip(self):
        source = CurriculumScheduler(
            strategy="competence_based", n_bins=3, temperature=0.8
        )
        for idx in range(20):
            source.update(f"item_{idx}", idx / 20.0)

        # Restore into a scheduler constructed with *different* settings:
        # loading must overwrite strategy, n_bins, and temperature.
        restored = CurriculumScheduler(strategy="uniform")
        restored.load_state_dict(source.state_dict())
        assert restored.strategy == "competence_based"
        assert restored.n_bins == 3
        assert math.isclose(restored.temperature, 0.8)
        assert restored.n_items_tracked == 20

        # Per-item difficulty scores must survive the round-trip exactly.
        for idx in range(20):
            name = f"item_{idx}"
            assert math.isclose(
                source.get_item_difficulty(name),
                restored.get_item_difficulty(name),
            )
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Constructor validation, parameter clamping, and degenerate inputs."""
    def test_invalid_strategy_raises(self):
        # Unknown strategy names must fail fast at construction time.
        with pytest.raises(ValueError, match="Invalid curriculum strategy"):
            CurriculumScheduler(strategy="invalid")
    def test_invalid_n_bins_raises(self):
        # Zero bins is meaningless for quantile binning; must raise.
        with pytest.raises(ValueError, match="n_bins must be >= 1"):
            CurriculumScheduler(n_bins=0)
    def test_temperature_floor(self):
        # Very small temperatures are floored (to >= 0.01) rather than
        # rejected — presumably to avoid numerical blow-up in the
        # softmax-style bin weighting; confirm against the implementation.
        scheduler = CurriculumScheduler(temperature=0.001)
        assert scheduler.temperature >= 0.01
    def test_ema_alpha_clamped(self):
        # Out-of-range smoothing factors are clamped into [0, 1], not rejected.
        scheduler = CurriculumScheduler(ema_alpha=2.0)
        assert scheduler.ema_alpha <= 1.0
        scheduler2 = CurriculumScheduler(ema_alpha=-1.0)
        assert scheduler2.ema_alpha >= 0.0
    def test_empty_batch_update(self):
        # An empty score list must be a no-op: no crash, no new tracked item.
        scheduler = CurriculumScheduler(strategy="uniform")
        scheduler.update_batch("item", [])
        assert scheduler.n_items_tracked == 0
    def test_strategy_enum_values(self):
        # Enum string values must stay stable — they presumably appear in
        # config files and saved checkpoints; verify against BaseEnv config.
        assert CurriculumStrategy.UNIFORM.value == "uniform"
        assert CurriculumStrategy.EASY_FIRST.value == "easy_first"
        assert CurriculumStrategy.COMPETENCE_BASED.value == "competence_based"