This commit is contained in:
Prakarsh Kaushik 2026-03-30 21:08:11 +00:00 committed by GitHub
commit 1ebd8aee3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 635 additions and 0 deletions

View file

@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
# Curriculum learning configuration. Consumed by BaseEnv.__init__, which
# constructs a CurriculumScheduler whenever the strategy is not "uniform".
curriculum_strategy: str = Field(
default="uniform",
description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
"'easy_first' = oversample easy items early then anneal, "
"'competence_based' = sample at competence frontier. "
"See Platanios et al. 2019 for competence-based curriculum.",
)
# Number of quantile bins the scheduler groups items into by difficulty.
curriculum_bins: int = Field(
default=5,
ge=1,
description="Number of difficulty bins for curriculum scheduling.",
)
# Softmax-style temperature applied when sampling a target difficulty bin.
curriculum_temperature: float = Field(
default=1.0,
gt=0,
description="Temperature for curriculum bin sampling. Higher = more uniform, "
"lower = more concentrated on target difficulty.",
)
class BaseEnv(ABC):
@ -262,6 +280,17 @@ class BaseEnv(ABC):
self.max_token_len = -1
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
self.completion_lengths = []
# Initialize curriculum scheduler (opt-in via config)
if config.curriculum_strategy != "uniform":
from atroposlib.envs.curriculum import CurriculumScheduler
self.curriculum = CurriculumScheduler(
strategy=config.curriculum_strategy,
n_bins=config.curriculum_bins,
temperature=config.curriculum_temperature,
)
else:
self.curriculum = None
self.max_num_workers = config.max_num_workers
if self.max_num_workers == -1:
self.max_num_workers = config.max_num_workers_per_node * len(
@ -674,6 +703,9 @@ class BaseEnv(ABC):
wandb_metrics["train/completion_lengths_p95"] = (
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
).mean()
# Log curriculum metrics if active
if self.curriculum is not None:
wandb_metrics.update(self.curriculum.metrics_dict())
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = self.perf_stats(wandb_metrics)
self.rollouts_for_wandb = []

View file

@ -0,0 +1,354 @@
"""
Curriculum learning scheduler for sample-efficient RL training.

Implements automatic difficulty-based sampling for environments, tracking
per-item difficulty from reward signals and adjusting sampling probabilities
to focus training on appropriately challenging examples.

Strategies:
- uniform: No curriculum (baseline, default)
- easy_first: Oversample easy items early, anneal to uniform
- competence_based: Sample items at the competence frontier (reward ~ 0.5),
  following Platanios et al. 2019 (https://arxiv.org/abs/1904.03746)

Usage:
    scheduler = CurriculumScheduler(
        strategy="competence_based",
        n_bins=5,
        temperature=1.0,
    )
    # After scoring an item
    scheduler.update("item_key_123", reward_score=0.7)
    # When selecting next item
    target_bin = scheduler.sample_bin(current_step=50, total_steps=1000)
"""

# NOTE: the docstring must precede the __future__ import; a future statement
# may only be preceded by the module docstring and comments. With the old
# ordering the string above was a dead expression and module.__doc__ was None.
from __future__ import annotations

import logging
import math
import random
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)
class CurriculumStrategy(str, Enum):
    """Available curriculum learning strategies."""

    UNIFORM = "uniform"
    EASY_FIRST = "easy_first"
    COMPETENCE_BASED = "competence_based"


class CurriculumScheduler:
    """
    Curriculum learning scheduler that tracks item difficulty and provides
    difficulty-aware sampling.

    Maintains an exponential moving average (EMA) of reward scores per item
    to estimate difficulty. Items are binned by difficulty quantile (bin 0 =
    easiest / highest reward, bin ``n_bins - 1`` = hardest), and the sampling
    strategy determines which bins are preferred at each stage of training.

    Args:
        strategy: Sampling strategy. One of "uniform", "easy_first",
            "competence_based".
        n_bins: Number of difficulty bins. Default: 5.
        temperature: Controls sampling sharpness. Higher = more uniform,
            lower = more concentrated on target bin. Default: 1.0.
        ema_alpha: EMA smoothing factor for difficulty scores. Higher values
            give more weight to recent rewards. Default: 0.3.
        competence_threshold: For competence_based strategy, the target
            reward level considered "at frontier". Default: 0.5.

    Raises:
        ValueError: If ``strategy`` is not a recognized strategy name, or
            ``n_bins`` is less than 1.
    """

    def __init__(
        self,
        strategy: str = "uniform",
        n_bins: int = 5,
        temperature: float = 1.0,
        ema_alpha: float = 0.3,
        competence_threshold: float = 0.5,
    ):
        # Validate strategy name against the enum.
        try:
            self._strategy = CurriculumStrategy(strategy)
        except ValueError:
            valid = [s.value for s in CurriculumStrategy]
            raise ValueError(
                f"Invalid curriculum strategy '{strategy}'. Must be one of: {valid}"
            )
        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")
        self.n_bins = n_bins
        # Floor the temperature so the division in sample_bin stays stable.
        self.temperature = max(0.01, temperature)
        # Clamp alpha into [0, 1] so the EMA stays a convex combination.
        self.ema_alpha = max(0.0, min(1.0, ema_alpha))
        self.competence_threshold = competence_threshold
        # Per-item difficulty tracking: key -> (ema_score, update_count)
        self._item_scores: Dict[str, Tuple[float, int]] = {}
        # Bin boundaries (descending scores; recomputed periodically)
        self._bin_boundaries: List[float] = []
        self._last_rebin_count: int = 0
        self._rebin_interval: int = 50  # Recompute bins every N updates

    @property
    def strategy(self) -> str:
        """Current strategy name."""
        return self._strategy.value

    @property
    def n_items_tracked(self) -> int:
        """Number of unique items being tracked."""
        return len(self._item_scores)

    def update(self, item_key: str, score: float) -> None:
        """
        Update difficulty estimate for an item based on its reward score.

        Uses exponential moving average so recent performance has more
        influence than historical.

        Args:
            item_key: Unique identifier for the item (e.g., dataset index).
            score: Reward score achieved on this item. Higher = easier.
        """
        if item_key in self._item_scores:
            old_ema, count = self._item_scores[item_key]
            new_ema = self.ema_alpha * score + (1 - self.ema_alpha) * old_ema
            self._item_scores[item_key] = (new_ema, count + 1)
        else:
            # First observation seeds the EMA directly.
            self._item_scores[item_key] = (score, 1)
        # Periodically recompute bin boundaries so quantiles track the
        # evolving difficulty distribution.
        total_updates = sum(c for _, c in self._item_scores.values())
        if total_updates - self._last_rebin_count >= self._rebin_interval:
            self._recompute_bins()
            self._last_rebin_count = total_updates

    def update_batch(self, item_key: str, scores: List[float]) -> None:
        """
        Update difficulty estimate with multiple scores (e.g., from group_size).

        The scores are collapsed to their mean and fed through a single EMA
        update. An empty list is a no-op.

        Args:
            item_key: Unique identifier for the item.
            scores: List of reward scores from the group rollout.
        """
        if not scores:
            return
        avg_score = sum(scores) / len(scores)
        self.update(item_key, avg_score)

    def get_item_difficulty(self, item_key: str) -> Optional[float]:
        """
        Get the current difficulty estimate for an item.

        Returns:
            EMA reward score (higher = easier), or None if item not tracked.
        """
        if item_key not in self._item_scores:
            return None
        return self._item_scores[item_key][0]

    def get_item_bin(self, item_key: str) -> int:
        """
        Get the difficulty bin for an item.

        Args:
            item_key: Unique identifier for the item.

        Returns:
            Bin index (0 = easiest, n_bins-1 = hardest).
            Returns middle bin if item is not tracked.
        """
        difficulty = self.get_item_difficulty(item_key)
        if difficulty is None:
            return self.n_bins // 2  # Default to middle bin
        if not self._bin_boundaries:
            # Lazily compute boundaries on first lookup before any rebin.
            self._recompute_bins()
        # Bin assignment: higher score = lower bin index (easier).
        # Boundaries are descending lower edges, so the first boundary the
        # difficulty clears identifies its bin.
        for i, boundary in enumerate(self._bin_boundaries):
            if difficulty >= boundary:
                return i
        return self.n_bins - 1

    def sample_bin(self, current_step: int = 0, total_steps: int = 1000) -> int:
        """
        Sample a target difficulty bin based on the curriculum strategy.

        Args:
            current_step: Current training step (for annealing strategies).
            total_steps: Total training steps planned.

        Returns:
            Target bin index to sample from (0 = easiest, n_bins-1 = hardest).
        """
        if self._strategy == CurriculumStrategy.UNIFORM:
            return random.randint(0, self.n_bins - 1)
        # Compute strategy-specific bin probabilities.
        probs = self._compute_bin_probabilities(current_step, total_steps)
        # Temperature-scaled sampling (softmax over log-probs).
        if self.temperature != 1.0:
            log_probs = [math.log(max(p, 1e-10)) / self.temperature for p in probs]
            max_lp = max(log_probs)  # subtract max for numerical stability
            exp_probs = [math.exp(lp - max_lp) for lp in log_probs]
            total = sum(exp_probs)
            probs = [p / total for p in exp_probs]
        # Weighted random choice over bins.
        return random.choices(range(self.n_bins), weights=probs, k=1)[0]

    def _compute_bin_probabilities(
        self, current_step: int, total_steps: int
    ) -> List[float]:
        """Compute sampling probabilities for each bin."""
        progress = min(1.0, max(0.0, current_step / max(1, total_steps)))
        if self._strategy == CurriculumStrategy.EASY_FIRST:
            return self._easy_first_probs(progress)
        elif self._strategy == CurriculumStrategy.COMPETENCE_BASED:
            return self._competence_based_probs(progress)
        else:
            # Uniform fallback
            return [1.0 / self.n_bins] * self.n_bins

    def _easy_first_probs(self, progress: float) -> List[float]:
        """
        Easy-first: linearly anneal from easy-biased to uniform.

        At progress=0: strongly prefer easy items (bin 0).
        At progress=1: uniform sampling across all bins.
        """
        probs = []
        for i in range(self.n_bins):
            # Base: uniform
            uniform_prob = 1.0 / self.n_bins
            # Bias: exponential decay favoring low bins (easy)
            easy_bias = math.exp(-2.0 * i / max(1, self.n_bins - 1))
            # Anneal from biased to uniform
            prob = (1.0 - progress) * easy_bias + progress * uniform_prob
            probs.append(prob)
        # Normalize to a proper distribution.
        total = sum(probs)
        return [p / total for p in probs]

    def _competence_based_probs(self, progress: float) -> List[float]:
        """
        Competence-based: sample items near the competence frontier.

        The frontier moves from easy to hard as training progresses.
        Items where expected reward ~ competence_threshold are preferred.
        """
        # Competence level increases with training progress;
        # maps to which bin is at the frontier.
        frontier_bin = progress * (self.n_bins - 1)
        probs = []
        for i in range(self.n_bins):
            # Gaussian-like probability centered on frontier bin
            distance = abs(i - frontier_bin)
            prob = math.exp(-0.5 * (distance**2))
            probs.append(prob)
        total = sum(probs)
        return [p / total for p in probs]

    def _recompute_bins(self) -> None:
        """Recompute bin boundaries based on current difficulty quantiles."""
        if not self._item_scores:
            self._bin_boundaries = []
            return
        # Sort scores descending (highest reward = easiest = bin 0)
        scores = sorted([ema for ema, _ in self._item_scores.values()], reverse=True)
        if len(scores) < self.n_bins:
            # Not enough items to properly bin, use equal spacing
            min_s = min(scores)
            max_s = max(scores)
            if max_s == min_s:
                self._bin_boundaries = [min_s] * self.n_bins
            else:
                step = (max_s - min_s) / self.n_bins
                self._bin_boundaries = [max_s - i * step for i in range(self.n_bins)]
            return
        # Quantile-based boundaries: boundary i is the score at the LOWER
        # edge of bin i so each bin holds ~len(scores)/n_bins items.
        # (Using the upper edge, scores[i * len / n_bins], would leave bin 0
        # holding only the single highest-scoring item and overfill the last
        # bin with roughly twice its share.)
        boundaries = []
        for i in range(self.n_bins):
            idx = (i + 1) * len(scores) // self.n_bins - 1
            idx = min(max(idx, 0), len(scores) - 1)
            boundaries.append(scores[idx])
        self._bin_boundaries = boundaries

    def metrics_dict(self) -> Dict[str, float]:
        """
        Return curriculum stats for WandB logging.

        Returns:
            Dictionary with keys suitable for wandb.log().
        """
        if not self._item_scores:
            return {
                "curriculum/items_tracked": 0,
                "curriculum/strategy": 0,  # Can't log strings to wandb
            }
        scores = [ema for ema, _ in self._item_scores.values()]
        counts = [c for _, c in self._item_scores.values()]
        metrics = {
            "curriculum/items_tracked": float(len(scores)),
            "curriculum/mean_difficulty": sum(scores) / len(scores),
            "curriculum/min_difficulty": min(scores),
            "curriculum/max_difficulty": max(scores),
            "curriculum/total_updates": float(sum(counts)),
        }
        # Bin distribution (only meaningful once boundaries exist).
        if self._bin_boundaries:
            bin_counts = [0] * self.n_bins
            for key in self._item_scores:
                bin_idx = self.get_item_bin(key)
                bin_counts[bin_idx] += 1
            for i, count in enumerate(bin_counts):
                metrics[f"curriculum/bin_{i}_count"] = float(count)
        return metrics

    def state_dict(self) -> Dict[str, Any]:
        """Serialize state for checkpointing."""
        return {
            "strategy": self._strategy.value,
            "n_bins": self.n_bins,
            "temperature": self.temperature,
            "ema_alpha": self.ema_alpha,
            "competence_threshold": self.competence_threshold,
            "item_scores": dict(self._item_scores),
            "bin_boundaries": self._bin_boundaries,
            "last_rebin_count": self._last_rebin_count,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore state from checkpoint."""
        self._strategy = CurriculumStrategy(state["strategy"])
        self.n_bins = state["n_bins"]
        self.temperature = state["temperature"]
        self.ema_alpha = state["ema_alpha"]
        self.competence_threshold = state["competence_threshold"]
        # tuple() handles values that round-tripped through JSON as lists.
        self._item_scores = {k: tuple(v) for k, v in state["item_scores"].items()}
        self._bin_boundaries = state["bin_boundaries"]
        self._last_rebin_count = state["last_rebin_count"]

View file

@ -0,0 +1,248 @@
"""
Tests for CurriculumScheduler -- difficulty-based sampling for RL training.
Tests cover:
- Uniform passthrough (default behavior unchanged)
- Easy-first annealing
- Competence-based frontier sampling
- EMA difficulty updates
- Bin assignment with quantile boundaries
- Metrics and state persistence
- Edge cases
"""
import math
import random
import pytest
from atroposlib.envs.curriculum import CurriculumScheduler, CurriculumStrategy
# ---------------------------------------------------------------------------
# Strategy tests
# ---------------------------------------------------------------------------
class TestUniformStrategy:
    def test_uniform_returns_valid_bins(self):
        """Every sampled bin index must fall inside [0, n_bins)."""
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        sampled = [sched.sample_bin(step, 1000) for step in range(100)]
        assert all(0 <= b < 5 for b in sampled)

    def test_uniform_covers_all_bins(self):
        """Uniform should eventually sample from every bin."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        seen = {sched.sample_bin(0, 1000) for _ in range(200)}
        assert seen == {0, 1, 2, 3, 4}
class TestEasyFirstStrategy:
    def test_early_training_prefers_easy(self):
        """At step 0, easy_first should strongly prefer low (easy) bins."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="easy_first", n_bins=5, temperature=0.5)
        draws = [sched.sample_bin(0, 1000) for _ in range(200)]
        n_easy = sum(b <= 1 for b in draws)
        n_hard = sum(b >= 3 for b in draws)
        # Early in training, easy draws should outnumber hard ones.
        assert n_easy > n_hard

    def test_late_training_approaches_uniform(self):
        """Near the end (step~total), easy_first should be roughly uniform."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="easy_first", n_bins=5, temperature=1.0)
        probs = sched._easy_first_probs(progress=1.0)
        # At progress=1.0 every bin probability should sit near 1/n_bins.
        assert all(abs(p - 0.2) < 0.05 for p in probs)
class TestCompetenceBasedStrategy:
    def test_competence_frontier_moves(self):
        """The preferred difficulty should shift upward as training advances."""
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )
        random.seed(42)
        # Early training: frontier sits on the easy bins.
        early = [sched.sample_bin(0, 1000) for _ in range(200)]
        # Late training: frontier has moved to the hard bins.
        late = [sched.sample_bin(900, 1000) for _ in range(200)]
        assert sum(late) / len(late) > sum(early) / len(early)

    def test_mid_training_prefers_middle(self):
        """At 50% progress the frontier sits on the middle bins."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )
        draws = [sched.sample_bin(500, 1000) for _ in range(300)]
        middle = sum(1 <= b <= 3 for b in draws)
        edges = sum(b in (0, 4) for b in draws)
        assert middle > edges
# ---------------------------------------------------------------------------
# EMA difficulty tracking tests
# ---------------------------------------------------------------------------
class TestDifficultyTracking:
    def test_ema_update(self):
        sched = CurriculumScheduler(strategy="uniform", ema_alpha=0.5)
        sched.update("item_1", 1.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 1.0)
        sched.update("item_1", 0.0)
        # EMA with alpha=0.5: 0.5 * 0.0 + 0.5 * 1.0 = 0.5
        assert math.isclose(sched.get_item_difficulty("item_1"), 0.5)

    def test_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item_1", [0.8, 0.6, 1.0])
        estimate = sched.get_item_difficulty("item_1")
        assert estimate is not None
        # Batch updates collapse to the mean: (0.8 + 0.6 + 1.0) / 3 = 0.8
        assert math.isclose(estimate, 0.8)

    def test_untracked_item_returns_none(self):
        sched = CurriculumScheduler(strategy="uniform")
        assert sched.get_item_difficulty("nonexistent") is None

    def test_multiple_items_tracked(self):
        sched = CurriculumScheduler(strategy="uniform")
        for key, score in [("easy", 0.9), ("hard", 0.1), ("medium", 0.5)]:
            sched.update(key, score)
        assert sched.n_items_tracked == 3
        assert sched.get_item_difficulty("easy") > sched.get_item_difficulty("hard")
# ---------------------------------------------------------------------------
# Bin assignment tests
# ---------------------------------------------------------------------------
class TestBinAssignment:
    def test_easy_item_gets_low_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        # Populate items spanning the whole difficulty range.
        for i in range(100):
            sched.update(f"item_{i}", i / 100.0)
        # Higher reward = easier = lower bin index.
        assert sched.get_item_bin("item_95") < sched.get_item_bin("item_5")

    def test_untracked_gets_middle_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        assert sched.get_item_bin("unknown") == 2  # n_bins // 2

    def test_single_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=1)
        sched.update("item", 0.5)
        assert sched.get_item_bin("item") == 0
# ---------------------------------------------------------------------------
# Metrics and state tests
# ---------------------------------------------------------------------------
class TestMetrics:
    def test_metrics_dict_empty(self):
        empty_metrics = CurriculumScheduler(strategy="uniform").metrics_dict()
        assert "curriculum/items_tracked" in empty_metrics
        assert empty_metrics["curriculum/items_tracked"] == 0

    def test_metrics_dict_populated(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=3)
        for i in range(60):  # enough updates to trigger rebinning
            sched.update(f"item_{i}", i / 60.0)
        metrics = sched.metrics_dict()
        assert metrics["curriculum/items_tracked"] == 60
        for key in (
            "curriculum/mean_difficulty",
            "curriculum/min_difficulty",
            "curriculum/max_difficulty",
            "curriculum/total_updates",
        ):
            assert key in metrics
class TestStatePersistence:
    def test_save_load_roundtrip(self):
        source = CurriculumScheduler(
            strategy="competence_based", n_bins=3, temperature=0.8
        )
        for i in range(20):
            source.update(f"item_{i}", i / 20.0)
        restored = CurriculumScheduler(strategy="uniform")
        restored.load_state_dict(source.state_dict())
        assert restored.strategy == "competence_based"
        assert restored.n_bins == 3
        assert math.isclose(restored.temperature, 0.8)
        assert restored.n_items_tracked == 20
        # Per-item difficulty estimates must survive the roundtrip.
        for i in range(20):
            key = f"item_{i}"
            assert math.isclose(
                source.get_item_difficulty(key), restored.get_item_difficulty(key)
            )
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    def test_invalid_strategy_raises(self):
        with pytest.raises(ValueError, match="Invalid curriculum strategy"):
            CurriculumScheduler(strategy="invalid")

    def test_invalid_n_bins_raises(self):
        with pytest.raises(ValueError, match="n_bins must be >= 1"):
            CurriculumScheduler(n_bins=0)

    def test_temperature_floor(self):
        # Tiny temperatures are clamped to keep sampling numerically stable.
        assert CurriculumScheduler(temperature=0.001).temperature >= 0.01

    def test_ema_alpha_clamped(self):
        assert CurriculumScheduler(ema_alpha=2.0).ema_alpha <= 1.0
        assert CurriculumScheduler(ema_alpha=-1.0).ema_alpha >= 0.0

    def test_empty_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item", [])  # no-op: nothing should be recorded
        assert sched.n_items_tracked == 0

    def test_strategy_enum_values(self):
        expected = {
            "UNIFORM": "uniform",
            "EASY_FIRST": "easy_first",
            "COMPETENCE_BASED": "competence_based",
        }
        for member, value in expected.items():
            assert CurriculumStrategy[member].value == value

View file

@ -26,6 +26,7 @@ dependencies = [
"jsonlines",
"pydantic-cli",
"hf_transfer",
"antlr4-python3-runtime==4.9.3",
]
[project.scripts]