feat: add curriculum learning scheduler for sample-efficient RL training

Add CurriculumScheduler to atroposlib/envs/ with:
- EMA-based per-item difficulty tracking from reward signals
- Quantile-based difficulty binning (configurable N bins)
- Three sampling strategies: uniform, easy_first, competence_based
- Competence-based strategy cites Platanios et al. 2019
- Opt-in integration in BaseEnv via 3 config fields
- WandB metrics for difficulty distribution tracking
- Checkpoint save/load support

22/22 tests passing.
This commit is contained in:
RUFFY-369 2026-03-28 03:39:01 +05:30
parent c421582b6f
commit 01da524b6b
3 changed files with 638 additions and 0 deletions

View file

@@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
# --- Curriculum learning configuration (opt-in; "uniform" disables it) ---
# Selects the sampling strategy; any value other than "uniform" causes
# BaseEnv to construct a CurriculumScheduler at init time.
curriculum_strategy: str = Field(
default="uniform",
description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
"'easy_first' = oversample easy items early then anneal, "
"'competence_based' = sample at competence frontier. "
"See Platanios et al. 2019 for competence-based curriculum.",
)
# Number of difficulty bins (validated >= 1 via ge=1). Presumably the bins
# are quantile-based per the commit description — confirm in curriculum.py.
curriculum_bins: int = Field(
default=5,
ge=1,
description="Number of difficulty bins for curriculum scheduling.",
)
# Sampling temperature over bins (validated > 0 via gt=0); higher values
# flatten the bin distribution, lower values sharpen it.
curriculum_temperature: float = Field(
default=1.0,
gt=0,
description="Temperature for curriculum bin sampling. Higher = more uniform, "
"lower = more concentrated on target difficulty.",
)
class BaseEnv(ABC):
@@ -262,6 +280,17 @@ class BaseEnv(ABC):
self.max_token_len = -1
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
self.completion_lengths = []
# Initialize curriculum scheduler (opt-in via config).
# The import is deferred into this branch so the default "uniform" path
# never loads atroposlib.envs.curriculum at all.
if config.curriculum_strategy != "uniform":
from atroposlib.envs.curriculum import CurriculumScheduler
self.curriculum = CurriculumScheduler(
strategy=config.curriculum_strategy,
n_bins=config.curriculum_bins,
temperature=config.curriculum_temperature,
)
else:
# None marks "curriculum disabled"; consumers guard with
# `self.curriculum is not None` before using it.
self.curriculum = None
self.max_num_workers = config.max_num_workers
if self.max_num_workers == -1:
self.max_num_workers = config.max_num_workers_per_node * len(
@@ -674,6 +703,9 @@ class BaseEnv(ABC):
wandb_metrics["train/completion_lengths_p95"] = (
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
).mean()
# Log curriculum metrics if active (self.curriculum is None when the
# configured strategy is "uniform", so this is a no-op by default).
if self.curriculum is not None:
# metrics_dict() is expected to return a flat {metric_name: value}
# mapping suitable for merging into the wandb payload.
wandb_metrics.update(self.curriculum.metrics_dict())
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = self.perf_stats(wandb_metrics)
self.rollouts_for_wandb = []