mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge 55ddbf49cf into c20c85256e
This commit is contained in:
commit
1ebd8aee3e
4 changed files with 635 additions and 0 deletions
|
|
@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
curriculum_strategy: str = Field(
|
||||
default="uniform",
|
||||
description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
|
||||
"'easy_first' = oversample easy items early then anneal, "
|
||||
"'competence_based' = sample at competence frontier. "
|
||||
"See Platanios et al. 2019 for competence-based curriculum.",
|
||||
)
|
||||
curriculum_bins: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Number of difficulty bins for curriculum scheduling.",
|
||||
)
|
||||
curriculum_temperature: float = Field(
|
||||
default=1.0,
|
||||
gt=0,
|
||||
description="Temperature for curriculum bin sampling. Higher = more uniform, "
|
||||
"lower = more concentrated on target difficulty.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@ -262,6 +280,17 @@ class BaseEnv(ABC):
|
|||
self.max_token_len = -1
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
|
||||
self.completion_lengths = []
|
||||
# Initialize curriculum scheduler (opt-in via config)
|
||||
if config.curriculum_strategy != "uniform":
|
||||
from atroposlib.envs.curriculum import CurriculumScheduler
|
||||
|
||||
self.curriculum = CurriculumScheduler(
|
||||
strategy=config.curriculum_strategy,
|
||||
n_bins=config.curriculum_bins,
|
||||
temperature=config.curriculum_temperature,
|
||||
)
|
||||
else:
|
||||
self.curriculum = None
|
||||
self.max_num_workers = config.max_num_workers
|
||||
if self.max_num_workers == -1:
|
||||
self.max_num_workers = config.max_num_workers_per_node * len(
|
||||
|
|
@ -674,6 +703,9 @@ class BaseEnv(ABC):
|
|||
wandb_metrics["train/completion_lengths_p95"] = (
|
||||
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
|
||||
).mean()
|
||||
# Log curriculum metrics if active
|
||||
if self.curriculum is not None:
|
||||
wandb_metrics.update(self.curriculum.metrics_dict())
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
||||
wandb_metrics = self.perf_stats(wandb_metrics)
|
||||
self.rollouts_for_wandb = []
|
||||
|
|
|
|||
354
atroposlib/envs/curriculum.py
Normal file
354
atroposlib/envs/curriculum.py
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Curriculum learning scheduler for sample-efficient RL training.
|
||||
|
||||
Implements automatic difficulty-based sampling for environments, tracking
|
||||
per-item difficulty from reward signals and adjusting sampling probabilities
|
||||
to focus training on appropriately challenging examples.
|
||||
|
||||
Strategies:
|
||||
- uniform: No curriculum (baseline, default)
|
||||
- easy_first: Oversample easy items early, anneal to uniform
|
||||
- competence_based: Sample items at the competence frontier (reward ~ 0.5),
|
||||
following Platanios et al. 2019 (https://arxiv.org/abs/1904.03746)
|
||||
|
||||
Usage:
|
||||
scheduler = CurriculumScheduler(
|
||||
strategy="competence_based",
|
||||
n_bins=5,
|
||||
temperature=1.0,
|
||||
)
|
||||
|
||||
# After scoring an item
|
||||
scheduler.update("item_key_123", reward_score=0.7)
|
||||
|
||||
# When selecting next item
|
||||
target_bin = scheduler.sample_bin(current_step=50, total_steps=1000)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CurriculumStrategy(str, Enum):
    """Available curriculum learning strategies."""

    # Baseline: every difficulty bin is equally likely (no curriculum).
    UNIFORM = "uniform"
    # Oversample easy bins early in training, annealing toward uniform.
    EASY_FIRST = "easy_first"
    # Prefer bins at the current competence frontier (Platanios et al. 2019).
    COMPETENCE_BASED = "competence_based"
|
||||
|
||||
|
||||
class CurriculumScheduler:
    """
    Curriculum learning scheduler that tracks item difficulty and provides
    difficulty-aware sampling.

    Maintains an exponential moving average (EMA) of reward scores per item
    to estimate difficulty. Items are binned by difficulty quantile, and the
    sampling strategy determines which bins are preferred at each stage of
    training.

    Args:
        strategy: Sampling strategy. One of "uniform", "easy_first",
            "competence_based".
        n_bins: Number of difficulty bins. Default: 5.
        temperature: Controls sampling sharpness. Higher = more uniform,
            lower = more concentrated on target bin. Default: 1.0.
        ema_alpha: EMA smoothing factor for difficulty scores. Higher values
            give more weight to recent rewards. Default: 0.3.
        competence_threshold: For competence_based strategy, the target
            reward level considered "at frontier". Default: 0.5.

    Raises:
        ValueError: If ``strategy`` is not a known strategy name, or if
            ``n_bins`` is less than 1.
    """

    def __init__(
        self,
        strategy: str = "uniform",
        n_bins: int = 5,
        temperature: float = 1.0,
        ema_alpha: float = 0.3,
        competence_threshold: float = 0.5,
    ):
        # Validate strategy up front so misconfiguration fails loudly.
        try:
            self._strategy = CurriculumStrategy(strategy)
        except ValueError:
            valid = [s.value for s in CurriculumStrategy]
            # `from None` suppresses the unhelpful enum-internal traceback.
            raise ValueError(
                f"Invalid curriculum strategy '{strategy}'. Must be one of: {valid}"
            ) from None

        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")

        self.n_bins = n_bins
        # Floor the temperature to avoid division blow-ups during sampling.
        self.temperature = max(0.01, temperature)
        # Clamp alpha into [0, 1] so the EMA stays a convex combination.
        self.ema_alpha = max(0.0, min(1.0, ema_alpha))
        self.competence_threshold = competence_threshold

        # Per-item difficulty tracking: key -> (ema_score, count)
        self._item_scores: Dict[str, Tuple[float, int]] = {}

        # Bin boundaries (recomputed periodically). boundaries[i] is the
        # LOWER bound (inclusive) of bin i, in descending score order.
        self._bin_boundaries: List[float] = []
        self._last_rebin_count: int = 0
        self._rebin_interval: int = 50  # Recompute bins every N updates
        # Running counter of update() calls so the rebin check stays O(1)
        # instead of re-summing every per-item count on each update.
        self._total_updates: int = 0

    @property
    def strategy(self) -> str:
        """Current strategy name."""
        return self._strategy.value

    @property
    def n_items_tracked(self) -> int:
        """Number of unique items being tracked."""
        return len(self._item_scores)

    def update(self, item_key: str, score: float) -> None:
        """
        Update difficulty estimate for an item based on its reward score.

        Uses exponential moving average so recent performance has more
        influence than historical.

        Args:
            item_key: Unique identifier for the item (e.g., dataset index).
            score: Reward score achieved on this item. Higher = easier.
        """
        if item_key in self._item_scores:
            old_ema, count = self._item_scores[item_key]
            new_ema = self.ema_alpha * score + (1 - self.ema_alpha) * old_ema
            self._item_scores[item_key] = (new_ema, count + 1)
        else:
            self._item_scores[item_key] = (score, 1)

        # Periodically recompute bin boundaries.
        self._total_updates += 1
        if self._total_updates - self._last_rebin_count >= self._rebin_interval:
            self._recompute_bins()
            self._last_rebin_count = self._total_updates

    def update_batch(self, item_key: str, scores: List[float]) -> None:
        """
        Update difficulty estimate with multiple scores (e.g., from group_size).

        Args:
            item_key: Unique identifier for the item.
            scores: List of reward scores from the group rollout.
                An empty list is a no-op.
        """
        if not scores:
            return
        avg_score = sum(scores) / len(scores)
        self.update(item_key, avg_score)

    def get_item_difficulty(self, item_key: str) -> Optional[float]:
        """
        Get the current difficulty estimate for an item.

        Returns:
            EMA reward score (higher = easier), or None if item not tracked.
        """
        if item_key not in self._item_scores:
            return None
        return self._item_scores[item_key][0]

    def get_item_bin(self, item_key: str) -> int:
        """
        Get the difficulty bin for an item.

        Args:
            item_key: Unique identifier for the item.

        Returns:
            Bin index (0 = easiest, n_bins-1 = hardest).
            Returns middle bin if item is not tracked.
        """
        difficulty = self.get_item_difficulty(item_key)
        if difficulty is None:
            return self.n_bins // 2  # Default to middle bin

        if not self._bin_boundaries:
            self._recompute_bins()

        # Boundaries are descending lower bounds: the first bin whose lower
        # bound the score reaches is the item's bin (bin 0 = highest reward).
        for i, boundary in enumerate(self._bin_boundaries):
            if difficulty >= boundary:
                return i
        return self.n_bins - 1

    def sample_bin(self, current_step: int = 0, total_steps: int = 1000) -> int:
        """
        Sample a target difficulty bin based on the curriculum strategy.

        Args:
            current_step: Current training step (for annealing strategies).
            total_steps: Total training steps planned.

        Returns:
            Target bin index to sample from (0 = easiest, n_bins-1 = hardest).
        """
        if self._strategy == CurriculumStrategy.UNIFORM:
            return random.randint(0, self.n_bins - 1)

        # Compute bin probabilities
        probs = self._compute_bin_probabilities(current_step, total_steps)

        # Temperature-scaled sampling (softmax over log-probs / T).
        if self.temperature != 1.0:
            log_probs = [math.log(max(p, 1e-10)) / self.temperature for p in probs]
            max_lp = max(log_probs)  # subtract max for numerical stability
            exp_probs = [math.exp(lp - max_lp) for lp in log_probs]
            total = sum(exp_probs)
            probs = [p / total for p in exp_probs]

        # Weighted random choice
        return random.choices(range(self.n_bins), weights=probs, k=1)[0]

    def _compute_bin_probabilities(
        self, current_step: int, total_steps: int
    ) -> List[float]:
        """Compute sampling probabilities for each bin."""
        progress = min(1.0, max(0.0, current_step / max(1, total_steps)))

        if self._strategy == CurriculumStrategy.EASY_FIRST:
            return self._easy_first_probs(progress)
        elif self._strategy == CurriculumStrategy.COMPETENCE_BASED:
            return self._competence_based_probs(progress)
        else:
            # Uniform fallback
            return [1.0 / self.n_bins] * self.n_bins

    def _easy_first_probs(self, progress: float) -> List[float]:
        """
        Easy-first: linearly anneal from easy-biased to uniform.

        At progress=0: strongly prefer easy items (bin 0).
        At progress=1: uniform sampling across all bins.
        """
        probs = []
        for i in range(self.n_bins):
            # Base: uniform
            uniform_prob = 1.0 / self.n_bins
            # Bias: exponential decay favoring low bins (easy)
            easy_bias = math.exp(-2.0 * i / max(1, self.n_bins - 1))
            # Anneal from biased to uniform
            prob = (1.0 - progress) * easy_bias + progress * uniform_prob
            probs.append(prob)

        # Normalize
        total = sum(probs)
        return [p / total for p in probs]

    def _competence_based_probs(self, progress: float) -> List[float]:
        """
        Competence-based: sample items near the competence frontier.

        The frontier moves from easy to hard as training progresses.
        Items where expected reward ~ competence_threshold are preferred.
        """
        # Competence level increases with training progress
        # Maps to which bin is at the frontier
        frontier_bin = progress * (self.n_bins - 1)

        probs = []
        for i in range(self.n_bins):
            # Gaussian-like probability centered on frontier bin
            distance = abs(i - frontier_bin)
            prob = math.exp(-0.5 * (distance**2))
            probs.append(prob)

        total = sum(probs)
        return [p / total for p in probs]

    def _recompute_bins(self) -> None:
        """Recompute bin boundaries based on current difficulty quantiles."""
        if not self._item_scores:
            self._bin_boundaries = []
            return

        # Sort scores descending (highest reward = easiest = bin 0)
        scores = sorted((ema for ema, _ in self._item_scores.values()), reverse=True)

        if len(scores) < self.n_bins:
            # Not enough items to properly bin, use equal spacing
            min_s = min(scores)
            max_s = max(scores)
            if max_s == min_s:
                self._bin_boundaries = [min_s] * self.n_bins
            else:
                step = (max_s - min_s) / self.n_bins
                # boundaries[i] is the lower edge of bin i; pin the last
                # boundary to min_s to guard against float round-off.
                self._bin_boundaries = [
                    max_s - (i + 1) * step for i in range(self.n_bins)
                ]
                self._bin_boundaries[-1] = min_s
            return

        # Quantile-based boundaries: boundaries[i] is the LOWEST score inside
        # the i-th quantile, so each bin receives ~len(scores)/n_bins items.
        # (Using the quantile's top score instead would leave bin 0 holding
        # only items tied with the maximum score.)
        boundaries = []
        for i in range(self.n_bins):
            idx = (i + 1) * len(scores) // self.n_bins - 1
            idx = max(0, min(idx, len(scores) - 1))
            boundaries.append(scores[idx])
        self._bin_boundaries = boundaries

    def metrics_dict(self) -> Dict[str, float]:
        """
        Return curriculum stats for WandB logging.

        Returns:
            Dictionary with keys suitable for wandb.log().
        """
        if not self._item_scores:
            return {
                "curriculum/items_tracked": 0,
                "curriculum/strategy": 0,  # Can't log strings to wandb
            }

        scores = [ema for ema, _ in self._item_scores.values()]
        counts = [c for _, c in self._item_scores.values()]

        metrics = {
            "curriculum/items_tracked": float(len(scores)),
            "curriculum/mean_difficulty": sum(scores) / len(scores),
            "curriculum/min_difficulty": min(scores),
            "curriculum/max_difficulty": max(scores),
            "curriculum/total_updates": float(sum(counts)),
        }

        # Bin distribution
        if self._bin_boundaries:
            bin_counts = [0] * self.n_bins
            for key in self._item_scores:
                bin_idx = self.get_item_bin(key)
                bin_counts[bin_idx] += 1
            for i, count in enumerate(bin_counts):
                metrics[f"curriculum/bin_{i}_count"] = float(count)

        return metrics

    def state_dict(self) -> Dict[str, Any]:
        """Serialize state for checkpointing."""
        return {
            "strategy": self._strategy.value,
            "n_bins": self.n_bins,
            "temperature": self.temperature,
            "ema_alpha": self.ema_alpha,
            "competence_threshold": self.competence_threshold,
            "item_scores": dict(self._item_scores),
            "bin_boundaries": self._bin_boundaries,
            "last_rebin_count": self._last_rebin_count,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore state from checkpoint."""
        self._strategy = CurriculumStrategy(state["strategy"])
        self.n_bins = state["n_bins"]
        self.temperature = state["temperature"]
        self.ema_alpha = state["ema_alpha"]
        self.competence_threshold = state["competence_threshold"]
        self._item_scores = {k: tuple(v) for k, v in state["item_scores"].items()}
        self._bin_boundaries = state["bin_boundaries"]
        self._last_rebin_count = state["last_rebin_count"]
        # Rebuild the running counter (not checkpointed, for backward
        # compatibility with older state dicts) from per-item counts.
        self._total_updates = sum(c for _, c in self._item_scores.values())
|
||||
248
atroposlib/tests/test_curriculum.py
Normal file
248
atroposlib/tests/test_curriculum.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Tests for CurriculumScheduler -- difficulty-based sampling for RL training.
|
||||
|
||||
Tests cover:
|
||||
- Uniform passthrough (default behavior unchanged)
|
||||
- Easy-first annealing
|
||||
- Competence-based frontier sampling
|
||||
- EMA difficulty updates
|
||||
- Bin assignment with quantile boundaries
|
||||
- Metrics and state persistence
|
||||
- Edge cases
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.curriculum import CurriculumScheduler, CurriculumStrategy
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUniformStrategy:
    def test_uniform_returns_valid_bins(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        sampled = [sched.sample_bin(step, 1000) for step in range(100)]
        assert all(b in range(5) for b in sampled)

    def test_uniform_covers_all_bins(self):
        """Uniform should eventually sample from every bin."""
        random.seed(42)
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        seen = {sched.sample_bin(0, 1000) for _ in range(200)}
        assert seen == set(range(5))
|
||||
|
||||
|
||||
class TestEasyFirstStrategy:
    def test_early_training_prefers_easy(self):
        """At step 0, easy_first should strongly prefer low (easy) bins."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=0.5
        )
        draws = [sched.sample_bin(0, 1000) for _ in range(200)]
        n_easy = len([b for b in draws if b <= 1])
        n_hard = len([b for b in draws if b >= 3])
        # Early in training, easy bins should dominate hard bins.
        assert n_easy > n_hard

    def test_late_training_approaches_uniform(self):
        """Near the end (step~total), easy_first should be roughly uniform."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="easy_first", n_bins=5, temperature=1.0
        )
        # At progress=1.0 every bin probability should sit near 1/n_bins.
        assert all(
            abs(p - 0.2) < 0.05 for p in sched._easy_first_probs(progress=1.0)
        )
|
||||
|
||||
|
||||
class TestCompetenceBasedStrategy:
    def test_competence_frontier_moves(self):
        """The frontier should shift from easy to hard as training progresses."""
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )

        random.seed(42)
        # Early training: frontier sits at the easy bins.
        early = [sched.sample_bin(0, 1000) for _ in range(200)]
        # Late training: frontier has moved toward the hard bins.
        late = [sched.sample_bin(900, 1000) for _ in range(200)]

        # The mean sampled bin index should grow over training.
        assert sum(late) / len(late) > sum(early) / len(early)

    def test_mid_training_prefers_middle(self):
        """At 50% progress, competence_based should prefer middle bins."""
        random.seed(42)
        sched = CurriculumScheduler(
            strategy="competence_based", n_bins=5, temperature=0.5
        )
        draws = [sched.sample_bin(500, 1000) for _ in range(300)]
        middle = sum(1 for b in draws if b in (1, 2, 3))
        edges = sum(1 for b in draws if b in (0, 4))
        assert middle > edges
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EMA difficulty tracking tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDifficultyTracking:
    def test_ema_update(self):
        sched = CurriculumScheduler(strategy="uniform", ema_alpha=0.5)
        sched.update("item_1", 1.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 1.0)

        # Second observation blends in via EMA: 0.5 * 0.0 + 0.5 * 1.0 = 0.5.
        sched.update("item_1", 0.0)
        assert math.isclose(sched.get_item_difficulty("item_1"), 0.5)

    def test_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item_1", [0.8, 0.6, 1.0])
        # The batch collapses to its mean: (0.8 + 0.6 + 1.0) / 3 == 0.8.
        estimate = sched.get_item_difficulty("item_1")
        assert estimate is not None
        assert math.isclose(estimate, 0.8)

    def test_untracked_item_returns_none(self):
        sched = CurriculumScheduler(strategy="uniform")
        assert sched.get_item_difficulty("nonexistent") is None

    def test_multiple_items_tracked(self):
        sched = CurriculumScheduler(strategy="uniform")
        for key, reward in (("easy", 0.9), ("hard", 0.1), ("medium", 0.5)):
            sched.update(key, reward)

        assert sched.n_items_tracked == 3
        assert sched.get_item_difficulty("easy") > sched.get_item_difficulty(
            "hard"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bin assignment tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinAssignment:
    def test_easy_item_gets_low_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        # Populate items spanning the whole difficulty range.
        for idx in range(100):
            sched.update(f"item_{idx}", idx / 100.0)

        # Higher reward = easier = lower bin index.
        assert sched.get_item_bin("item_95") < sched.get_item_bin("item_5")

    def test_untracked_gets_middle_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=5)
        assert sched.get_item_bin("unknown") == 2  # n_bins // 2

    def test_single_bin(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=1)
        sched.update("item", 0.5)
        assert sched.get_item_bin("item") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics and state tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMetrics:
    def test_metrics_dict_empty(self):
        metrics = CurriculumScheduler(strategy="uniform").metrics_dict()
        assert "curriculum/items_tracked" in metrics
        assert metrics["curriculum/items_tracked"] == 0

    def test_metrics_dict_populated(self):
        sched = CurriculumScheduler(strategy="uniform", n_bins=3)
        # 60 updates is enough to trigger at least one rebinning pass.
        for idx in range(60):
            sched.update(f"item_{idx}", idx / 60.0)

        metrics = sched.metrics_dict()
        assert metrics["curriculum/items_tracked"] == 60
        for key in (
            "curriculum/mean_difficulty",
            "curriculum/min_difficulty",
            "curriculum/max_difficulty",
            "curriculum/total_updates",
        ):
            assert key in metrics
|
||||
|
||||
|
||||
class TestStatePersistence:
    def test_save_load_roundtrip(self):
        source = CurriculumScheduler(
            strategy="competence_based", n_bins=3, temperature=0.8
        )
        for idx in range(20):
            source.update(f"item_{idx}", idx / 20.0)

        # Load the checkpoint into a scheduler built with different settings.
        restored = CurriculumScheduler(strategy="uniform")
        restored.load_state_dict(source.state_dict())

        assert restored.strategy == "competence_based"
        assert restored.n_bins == 3
        assert math.isclose(restored.temperature, 0.8)
        assert restored.n_items_tracked == 20

        # Per-item difficulty estimates must survive the roundtrip.
        for idx in range(20):
            key = f"item_{idx}"
            assert math.isclose(
                source.get_item_difficulty(key), restored.get_item_difficulty(key)
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
    def test_invalid_strategy_raises(self):
        with pytest.raises(ValueError, match="Invalid curriculum strategy"):
            CurriculumScheduler(strategy="invalid")

    def test_invalid_n_bins_raises(self):
        with pytest.raises(ValueError, match="n_bins must be >= 1"):
            CurriculumScheduler(n_bins=0)

    def test_temperature_floor(self):
        # Sub-floor temperatures are clamped up to the 0.01 minimum.
        assert CurriculumScheduler(temperature=0.001).temperature >= 0.01

    def test_ema_alpha_clamped(self):
        # Alpha is clamped into [0, 1] from either side.
        assert CurriculumScheduler(ema_alpha=2.0).ema_alpha <= 1.0
        assert CurriculumScheduler(ema_alpha=-1.0).ema_alpha >= 0.0

    def test_empty_batch_update(self):
        sched = CurriculumScheduler(strategy="uniform")
        sched.update_batch("item", [])
        assert sched.n_items_tracked == 0

    def test_strategy_enum_values(self):
        expected = {
            CurriculumStrategy.UNIFORM: "uniform",
            CurriculumStrategy.EASY_FIRST: "easy_first",
            CurriculumStrategy.COMPETENCE_BASED: "competence_based",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||
|
|
@ -26,6 +26,7 @@ dependencies = [
|
|||
"jsonlines",
|
||||
"pydantic-cli",
|
||||
"hf_transfer",
|
||||
"antlr4-python3-runtime==4.9.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue