diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 3d3b6c20..e21c55e2 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) + curriculum_strategy: str = Field( + default="uniform", + description="Curriculum learning strategy. 'uniform' = no curriculum (default), " + "'easy_first' = oversample easy items early then anneal, " + "'competence_based' = sample at competence frontier. " + "See Platanios et al. 2019 for competence-based curriculum.", + ) + curriculum_bins: int = Field( + default=5, + ge=1, + description="Number of difficulty bins for curriculum scheduling.", + ) + curriculum_temperature: float = Field( + default=1.0, + gt=0, + description="Temperature for curriculum bin sampling. Higher = more uniform, " + "lower = more concentrated on target difficulty.", + ) class BaseEnv(ABC): @@ -262,6 +280,17 @@ class BaseEnv(ABC): self.max_token_len = -1 self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) self.completion_lengths = [] + # Initialize curriculum scheduler (opt-in via config) + if config.curriculum_strategy != "uniform": + from atroposlib.envs.curriculum import CurriculumScheduler + + self.curriculum = CurriculumScheduler( + strategy=config.curriculum_strategy, + n_bins=config.curriculum_bins, + temperature=config.curriculum_temperature, + ) + else: + self.curriculum = None self.max_num_workers = config.max_num_workers if self.max_num_workers == -1: self.max_num_workers = config.max_num_workers_per_node * len( @@ -674,6 +703,9 @@ class BaseEnv(ABC): wandb_metrics["train/completion_lengths_p95"] = ( np.array(self.completion_lengths) > (0.95 * self.max_token_len) ).mean() + # Log curriculum metrics if active + if self.curriculum is not None: + wandb_metrics.update(self.curriculum.metrics_dict()) wandb_metrics = await 
"""
Curriculum learning scheduler for sample-efficient RL training.

Implements automatic difficulty-based sampling for environments, tracking
per-item difficulty from reward signals and adjusting sampling probabilities
to focus training on appropriately challenging examples.

Strategies:
- uniform: No curriculum (baseline, default)
- easy_first: Oversample easy items early, anneal to uniform
- competence_based: Sample items at the competence frontier (reward ~ 0.5),
  following Platanios et al. 2019 (https://arxiv.org/abs/1904.03746)

Usage:
    scheduler = CurriculumScheduler(
        strategy="competence_based",
        n_bins=5,
        temperature=1.0,
    )

    # After scoring an item
    scheduler.update("item_key_123", score=0.7)

    # When selecting next item
    target_bin = scheduler.sample_bin(current_step=50, total_steps=1000)
"""

import logging
import math
import random
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class CurriculumStrategy(str, Enum):
    """Available curriculum learning strategies."""

    UNIFORM = "uniform"
    EASY_FIRST = "easy_first"
    COMPETENCE_BASED = "competence_based"


class CurriculumScheduler:
    """
    Curriculum learning scheduler that tracks item difficulty and provides
    difficulty-aware sampling.

    Maintains an exponential moving average (EMA) of reward scores per item
    to estimate difficulty. Items are binned by difficulty quantile, and the
    sampling strategy determines which bins are preferred at each stage of
    training.

    Args:
        strategy: Sampling strategy. One of "uniform", "easy_first",
            "competence_based".
        n_bins: Number of difficulty bins. Default: 5.
        temperature: Controls sampling sharpness. Higher = more uniform,
            lower = more concentrated on target bin. Default: 1.0.
        ema_alpha: EMA smoothing factor for difficulty scores. Higher values
            give more weight to recent rewards. Default: 0.3.
        competence_threshold: For competence_based strategy, the target
            reward level considered "at frontier". Default: 0.5.

    Raises:
        ValueError: If ``strategy`` is not a known strategy name, or if
            ``n_bins`` is less than 1.
    """

    def __init__(
        self,
        strategy: str = "uniform",
        n_bins: int = 5,
        temperature: float = 1.0,
        ema_alpha: float = 0.3,
        competence_threshold: float = 0.5,
    ):
        # Validate strategy up front so misconfiguration fails loudly.
        try:
            self._strategy = CurriculumStrategy(strategy)
        except ValueError:
            valid = [s.value for s in CurriculumStrategy]
            # `from None` suppresses the uninformative chained Enum ValueError.
            raise ValueError(
                f"Invalid curriculum strategy '{strategy}'. Must be one of: {valid}"
            ) from None

        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")

        self.n_bins = n_bins
        # Floor the temperature so the softmax in sample_bin never divides by ~0.
        self.temperature = max(0.01, temperature)
        # Clamp alpha into [0, 1] so the EMA stays a convex combination.
        self.ema_alpha = max(0.0, min(1.0, ema_alpha))
        self.competence_threshold = competence_threshold

        # Per-item difficulty tracking: key -> (ema_score, count)
        self._item_scores: Dict[str, Tuple[float, int]] = {}

        # Running total of update() calls (O(1) bookkeeping instead of
        # summing per-item counts on every update).
        self._n_updates: int = 0

        # Bin boundaries (recomputed periodically)
        self._bin_boundaries: List[float] = []
        self._last_rebin_count: int = 0
        self._rebin_interval: int = 50  # Recompute bins every N updates

    @property
    def strategy(self) -> str:
        """Current strategy name."""
        return self._strategy.value

    @property
    def n_items_tracked(self) -> int:
        """Number of unique items being tracked."""
        return len(self._item_scores)

    def update(self, item_key: str, score: float) -> None:
        """
        Update difficulty estimate for an item based on its reward score.

        Uses exponential moving average so recent performance has more
        influence than historical.

        Args:
            item_key: Unique identifier for the item (e.g., dataset index).
            score: Reward score achieved on this item. Higher = easier.
        """
        if item_key in self._item_scores:
            old_ema, count = self._item_scores[item_key]
            new_ema = self.ema_alpha * score + (1 - self.ema_alpha) * old_ema
            self._item_scores[item_key] = (new_ema, count + 1)
        else:
            # First observation: seed the EMA directly with the raw score.
            self._item_scores[item_key] = (score, 1)

        self._n_updates += 1

        # Periodically recompute bin boundaries so quantiles track the
        # evolving difficulty distribution.
        if self._n_updates - self._last_rebin_count >= self._rebin_interval:
            self._recompute_bins()
            self._last_rebin_count = self._n_updates

    def update_batch(self, item_key: str, scores: List[float]) -> None:
        """
        Update difficulty estimate with multiple scores (e.g., from group_size).

        Args:
            item_key: Unique identifier for the item.
            scores: List of reward scores from the group rollout. An empty
                list is a no-op.
        """
        if not scores:
            return
        avg_score = sum(scores) / len(scores)
        self.update(item_key, avg_score)

    def get_item_difficulty(self, item_key: str) -> Optional[float]:
        """
        Get the current difficulty estimate for an item.

        Returns:
            EMA reward score (higher = easier), or None if item not tracked.
        """
        if item_key not in self._item_scores:
            return None
        return self._item_scores[item_key][0]

    def get_item_bin(self, item_key: str) -> int:
        """
        Get the difficulty bin for an item.

        Args:
            item_key: Unique identifier for the item.

        Returns:
            Bin index (0 = easiest, n_bins-1 = hardest).
            Returns middle bin if item is not tracked.
        """
        difficulty = self.get_item_difficulty(item_key)
        if difficulty is None:
            return self.n_bins // 2  # Default to middle bin

        if not self._bin_boundaries:
            self._recompute_bins()

        # Bin assignment: higher score = lower bin index (easier).
        # Each boundary is the inclusive lower edge of its bin.
        for i, boundary in enumerate(self._bin_boundaries):
            if difficulty >= boundary:
                return i
        return self.n_bins - 1

    def sample_bin(self, current_step: int = 0, total_steps: int = 1000) -> int:
        """
        Sample a target difficulty bin based on the curriculum strategy.

        Args:
            current_step: Current training step (for annealing strategies).
            total_steps: Total training steps planned.

        Returns:
            Target bin index to sample from (0 = easiest, n_bins-1 = hardest).
        """
        if self._strategy == CurriculumStrategy.UNIFORM:
            return random.randint(0, self.n_bins - 1)

        # Compute bin probabilities
        probs = self._compute_bin_probabilities(current_step, total_steps)

        # Temperature-scaled sampling: softmax of log-probs / T. At T=1 this
        # is a no-op, so we skip the rescale entirely.
        if self.temperature != 1.0:
            log_probs = [math.log(max(p, 1e-10)) / self.temperature for p in probs]
            max_lp = max(log_probs)  # subtract max for numerical stability
            exp_probs = [math.exp(lp - max_lp) for lp in log_probs]
            total = sum(exp_probs)
            probs = [p / total for p in exp_probs]

        # Weighted random choice
        return random.choices(range(self.n_bins), weights=probs, k=1)[0]

    def _compute_bin_probabilities(
        self, current_step: int, total_steps: int
    ) -> List[float]:
        """Compute sampling probabilities for each bin."""
        # Clamp progress into [0, 1]; guard against total_steps == 0.
        progress = min(1.0, max(0.0, current_step / max(1, total_steps)))

        if self._strategy == CurriculumStrategy.EASY_FIRST:
            return self._easy_first_probs(progress)
        elif self._strategy == CurriculumStrategy.COMPETENCE_BASED:
            return self._competence_based_probs(progress)
        else:
            # Uniform fallback
            return [1.0 / self.n_bins] * self.n_bins

    def _easy_first_probs(self, progress: float) -> List[float]:
        """
        Easy-first: linearly anneal from easy-biased to uniform.

        At progress=0: strongly prefer easy items (bin 0).
        At progress=1: uniform sampling across all bins.
        """
        probs = []
        for i in range(self.n_bins):
            # Base: uniform
            uniform_prob = 1.0 / self.n_bins
            # Bias: exponential decay favoring low bins (easy)
            easy_bias = math.exp(-2.0 * i / max(1, self.n_bins - 1))
            # Anneal from biased to uniform
            prob = (1.0 - progress) * easy_bias + progress * uniform_prob
            probs.append(prob)

        # Normalize
        total = sum(probs)
        return [p / total for p in probs]

    def _competence_based_probs(self, progress: float) -> List[float]:
        """
        Competence-based: sample items near the competence frontier.

        The frontier moves from easy to hard as training progresses.
        Items where expected reward ~ competence_threshold are preferred.
        """
        # Competence level increases with training progress.
        # Maps to which bin is at the frontier.
        frontier_bin = progress * (self.n_bins - 1)

        probs = []
        for i in range(self.n_bins):
            # Gaussian-like probability centered on frontier bin
            distance = abs(i - frontier_bin)
            prob = math.exp(-0.5 * (distance**2))
            probs.append(prob)

        total = sum(probs)
        return [p / total for p in probs]

    def _recompute_bins(self) -> None:
        """
        Recompute bin boundaries based on current difficulty quantiles.

        Each boundary is the *inclusive lower edge* of its bin: bin i holds
        items whose EMA score >= boundaries[i] (checked in ascending bin
        order by get_item_bin), so bin 0 is the easiest quantile.
        """
        if not self._item_scores:
            self._bin_boundaries = []
            return

        # Sort scores descending (highest reward = easiest = bin 0)
        scores = sorted([ema for ema, _ in self._item_scores.values()], reverse=True)

        if len(scores) < self.n_bins:
            # Not enough items to properly bin, use equal spacing.
            min_s = min(scores)
            max_s = max(scores)
            if max_s == min_s:
                self._bin_boundaries = [min_s] * self.n_bins
            else:
                step = (max_s - min_s) / self.n_bins
                # Lower edge of bin i is one step below its upper edge, so
                # bin 0 spans [max_s - step, max_s] rather than only {max_s}.
                self._bin_boundaries = [
                    max_s - (i + 1) * step for i in range(self.n_bins)
                ]
            return

        # Quantile-based boundaries: the lower edge of bin i is the lowest
        # score still inside bin i, i.e. the score at descending rank
        # (i+1)*len/n - 1. (Using i*len/n here would make bin 0 contain only
        # the single maximum score.)
        boundaries = []
        for i in range(self.n_bins):
            idx = int((i + 1) * len(scores) / self.n_bins) - 1
            idx = min(max(idx, 0), len(scores) - 1)
            boundaries.append(scores[idx])
        self._bin_boundaries = boundaries

    def metrics_dict(self) -> Dict[str, float]:
        """
        Return curriculum stats for WandB logging.

        Returns:
            Dictionary with keys suitable for wandb.log().
        """
        if not self._item_scores:
            return {
                "curriculum/items_tracked": 0,
                "curriculum/strategy": 0,  # Can't log strings to wandb
            }

        scores = [ema for ema, _ in self._item_scores.values()]
        counts = [c for _, c in self._item_scores.values()]

        metrics = {
            "curriculum/items_tracked": float(len(scores)),
            "curriculum/mean_difficulty": sum(scores) / len(scores),
            "curriculum/min_difficulty": min(scores),
            "curriculum/max_difficulty": max(scores),
            "curriculum/total_updates": float(sum(counts)),
        }

        # Bin distribution
        if self._bin_boundaries:
            bin_counts = [0] * self.n_bins
            for key in self._item_scores:
                bin_idx = self.get_item_bin(key)
                bin_counts[bin_idx] += 1
            for i, count in enumerate(bin_counts):
                metrics[f"curriculum/bin_{i}_count"] = float(count)

        return metrics

    def state_dict(self) -> Dict[str, Any]:
        """Serialize state for checkpointing."""
        return {
            "strategy": self._strategy.value,
            "n_bins": self.n_bins,
            "temperature": self.temperature,
            "ema_alpha": self.ema_alpha,
            "competence_threshold": self.competence_threshold,
            "item_scores": dict(self._item_scores),
            "bin_boundaries": self._bin_boundaries,
            "last_rebin_count": self._last_rebin_count,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore state from checkpoint."""
        self._strategy = CurriculumStrategy(state["strategy"])
        self.n_bins = state["n_bins"]
        self.temperature = state["temperature"]
        self.ema_alpha = state["ema_alpha"]
        self.competence_threshold = state["competence_threshold"]
        self._item_scores = {k: tuple(v) for k, v in state["item_scores"].items()}
        self._bin_boundaries = state["bin_boundaries"]
        self._last_rebin_count = state["last_rebin_count"]
        # The update counter is not checkpointed (kept backward-compatible
        # with older state dicts); rebuild it from the per-item counts.
        self._n_updates = sum(c for _, c in self._item_scores.values())
"""
Tests for CurriculumScheduler -- difficulty-based sampling for RL training.

Tests cover:
- Uniform passthrough (default behavior unchanged)
- Easy-first annealing
- Competence-based frontier sampling
- EMA difficulty updates
- Bin assignment with quantile boundaries
- Metrics and state persistence
- Edge cases
"""

import math
import random

import pytest

from atroposlib.envs.curriculum import CurriculumScheduler, CurriculumStrategy

# ---------------------------------------------------------------------------
# Strategy tests
# ---------------------------------------------------------------------------


class TestUniformStrategy:
    def test_uniform_returns_valid_bins(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        sampled = [sched.sample_bin(step, 1000) for step in range(100)]
        assert all(0 <= b < 5 for b in sampled)

    def test_uniform_covers_all_bins(self):
        """Uniform should eventually sample from every bin."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        seen = {sched.sample_bin(0, 1000) for _ in range(200)}
        assert seen == {0, 1, 2, 3, 4}


class TestEasyFirstStrategy:
    def test_early_training_prefers_easy(self):
        """At step 0, easy_first should strongly prefer low bins (easy)."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=0.5
        )
        draws = [sched.sample_bin(0, 1000) for _ in range(200)]
        n_easy = sum(1 for b in draws if b <= 1)
        n_hard = sum(1 for b in draws if b >= 3)
        # Early training should have more easy than hard
        assert n_easy > n_hard

    def test_late_training_approaches_uniform(self):
        """Near the end (step~total), easy_first should be roughly uniform."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=1.0
        )
        probs = sched._easy_first_probs(progress=1.0)
        # At progress=1.0, all probs should be near 1/n_bins
        for p in probs:
            assert abs(p - 0.2) < 0.05


class TestCompetenceBasedStrategy:
    def test_competence_frontier_moves(self):
        """The frontier should shift from easy to hard as training progresses."""
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )

        # Early training: frontier at easy bins
        random.seed(42)
        early = [sched.sample_bin(0, 1000) for _ in range(200)]
        early_mean = sum(early) / len(early)

        # Late training: frontier at hard bins
        late = [sched.sample_bin(900, 1000) for _ in range(200)]
        late_mean = sum(late) / len(late)

        # Late mean should be higher (harder bins)
        assert late_mean > early_mean

    def test_mid_training_prefers_middle(self):
        """At 50% progress, competence_based should prefer middle bins."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )
        draws = [sched.sample_bin(500, 1000) for _ in range(300)]
        n_mid = sum(1 for b in draws if 1 <= b <= 3)
        n_edge = sum(1 for b in draws if b in (0, 4))
        assert n_mid > n_edge


# ---------------------------------------------------------------------------
# EMA difficulty tracking tests
# ---------------------------------------------------------------------------


class TestDifficultyTracking:
    def test_ema_update(self):
        sched = CurriculumScheduler(strategy="uniform", ema_alpha=0.5)
        sched.update("item_1", 1.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 1.0)

        sched.update("item_1", 0.0)
        # EMA: 0.5 * 0.0 + 0.5 * 1.0 = 0.5
        assert math.isclose(sched.get_item_difficulty("item_1"), 0.5)

    def test_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item_1", [0.8, 0.6, 1.0])
        # Should use average: 0.8
        est = sched.get_item_difficulty("item_1")
        assert est is not None
        assert math.isclose(est, 0.8)

    def test_untracked_item_returns_none(self):
        sched = CurriculumScheduler(strategy="uniform")
        assert sched.get_item_difficulty("nonexistent") is None

    def test_multiple_items_tracked(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update("easy", 0.9)
        sched.update("hard", 0.1)
        sched.update("medium", 0.5)

        assert sched.n_items_tracked == 3
        assert sched.get_item_difficulty("easy") > sched.get_item_difficulty(
            "hard"
        )


# ---------------------------------------------------------------------------
# Bin assignment tests
# ---------------------------------------------------------------------------


class TestBinAssignment:
    def test_easy_item_gets_low_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        # Create items spanning the difficulty range
        for idx in range(100):
            sched.update(f"item_{idx}", idx / 100.0)

        # High score = easy = low bin
        easy_bin = sched.get_item_bin("item_95")
        hard_bin = sched.get_item_bin("item_5")
        assert easy_bin < hard_bin

    def test_untracked_gets_middle_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        assert sched.get_item_bin("unknown") == 2  # n_bins // 2

    def test_single_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=1)
        sched.update("item", 0.5)
        assert sched.get_item_bin("item") == 0


# ---------------------------------------------------------------------------
# Metrics and state tests
# ---------------------------------------------------------------------------


class TestMetrics:
    def test_metrics_dict_empty(self):
        sched = CurriculumScheduler(strategy="uniform")
        metrics = sched.metrics_dict()
        assert "curriculum/items_tracked" in metrics
        assert metrics["curriculum/items_tracked"] == 0

    def test_metrics_dict_populated(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=3)
        for idx in range(60):  # Enough to trigger rebinning
            sched.update(f"item_{idx}", idx / 60.0)

        metrics = sched.metrics_dict()
        assert metrics["curriculum/items_tracked"] == 60
        assert "curriculum/mean_difficulty" in metrics
        assert "curriculum/min_difficulty" in metrics
        assert "curriculum/max_difficulty" in metrics
        assert "curriculum/total_updates" in metrics


class TestStatePersistence:
    def test_save_load_roundtrip(self):
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=3, temperature=0.8
        )
        for idx in range(20):
            sched.update(f"item_{idx}", idx / 20.0)

        snapshot = sched.state_dict()

        restored = CurriculumScheduler(strategy="uniform")
        restored.load_state_dict(snapshot)

        assert restored.strategy == "competence_based"
        assert restored.n_bins == 3
        assert math.isclose(restored.temperature, 0.8)
        assert restored.n_items_tracked == 20

        # Difficulty scores should match
        for idx in range(20):
            key = f"item_{idx}"
            before = sched.get_item_difficulty(key)
            after = restored.get_item_difficulty(key)
            assert math.isclose(before, after)


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestEdgeCases:
    def test_invalid_strategy_raises(self):
        with pytest.raises(ValueError, match="Invalid curriculum strategy"):
            CurriculumScheduler(strategy="invalid")

    def test_invalid_n_bins_raises(self):
        with pytest.raises(ValueError, match="n_bins must be >= 1"):
            CurriculumScheduler(n_bins=0)

    def test_temperature_floor(self):
        sched = CurriculumScheduler(temperature=0.001)
        assert sched.temperature >= 0.01

    def test_ema_alpha_clamped(self):
        clamped_high = CurriculumScheduler(ema_alpha=2.0)
        assert clamped_high.ema_alpha <= 1.0

        clamped_low = CurriculumScheduler(ema_alpha=-1.0)
        assert clamped_low.ema_alpha >= 0.0

    def test_empty_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item", [])
        assert sched.n_items_tracked == 0

    def test_strategy_enum_values(self):
        assert CurriculumStrategy.UNIFORM.value == "uniform"
        assert CurriculumStrategy.EASY_FIRST.value == "easy_first"
        assert CurriculumStrategy.COMPETENCE_BASED.value == "competence_based"


# NOTE(review): the same patch also pins `antlr4-python3-runtime==4.9.3` in
# pyproject.toml, which appears unrelated to the curriculum feature — confirm
# that dependency addition belongs in this change set before merging.