mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
feat: add curriculum learning scheduler for sample-efficient RL training
Add CurriculumScheduler to atroposlib/envs/ with: - EMA-based per-item difficulty tracking from reward signals - Quantile-based difficulty binning (configurable N bins) - Three sampling strategies: uniform, easy_first, competence_based - Competence-based strategy cites Platanios et al. 2019 - Opt-in integration in BaseEnv via 3 config fields - WandB metrics for difficulty distribution tracking - Checkpoint save/load support 22/22 tests passing.
This commit is contained in:
parent
c421582b6f
commit
01da524b6b
3 changed files with 638 additions and 0 deletions
|
|
@@ -211,6 +211,24 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
curriculum_strategy: str = Field(
|
||||
default="uniform",
|
||||
description="Curriculum learning strategy. 'uniform' = no curriculum (default), "
|
||||
"'easy_first' = oversample easy items early then anneal, "
|
||||
"'competence_based' = sample at competence frontier. "
|
||||
"See Platanios et al. 2019 for competence-based curriculum.",
|
||||
)
|
||||
curriculum_bins: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Number of difficulty bins for curriculum scheduling.",
|
||||
)
|
||||
curriculum_temperature: float = Field(
|
||||
default=1.0,
|
||||
gt=0,
|
||||
description="Temperature for curriculum bin sampling. Higher = more uniform, "
|
||||
"lower = more concentrated on target difficulty.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@@ -262,6 +280,17 @@ class BaseEnv(ABC):
|
|||
self.max_token_len = -1
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
|
||||
self.completion_lengths = []
|
||||
# Initialize curriculum scheduler (opt-in via config)
|
||||
if config.curriculum_strategy != "uniform":
|
||||
from atroposlib.envs.curriculum import CurriculumScheduler
|
||||
|
||||
self.curriculum = CurriculumScheduler(
|
||||
strategy=config.curriculum_strategy,
|
||||
n_bins=config.curriculum_bins,
|
||||
temperature=config.curriculum_temperature,
|
||||
)
|
||||
else:
|
||||
self.curriculum = None
|
||||
self.max_num_workers = config.max_num_workers
|
||||
if self.max_num_workers == -1:
|
||||
self.max_num_workers = config.max_num_workers_per_node * len(
|
||||
|
|
@@ -674,6 +703,9 @@ class BaseEnv(ABC):
|
|||
wandb_metrics["train/completion_lengths_p95"] = (
|
||||
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
|
||||
).mean()
|
||||
# Log curriculum metrics if active
|
||||
if self.curriculum is not None:
|
||||
wandb_metrics.update(self.curriculum.metrics_dict())
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
||||
wandb_metrics = self.perf_stats(wandb_metrics)
|
||||
self.rollouts_for_wandb = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue