mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
feat: add online reward normalization for multi-env RL training stability
Add RewardNormalizer to atroposlib/envs/ with:
- Welford's online algorithm for running mean/variance (no data storage)
- Z-score and min-max normalization modes
- Configurable reward clipping and warmup period
- Checkpoint save/load support
- Opt-in integration in BaseEnv via 3 new config fields
- WandB metrics for normalization statistics

All 21 tests pass (21/21).
This commit is contained in:
parent
c421582b6f
commit
0674e31a53
3 changed files with 560 additions and 0 deletions
|
|
@ -211,6 +211,23 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
reward_normalization: str = Field(
|
||||
default="none",
|
||||
description="Reward normalization mode. 'none' = disabled (default), "
|
||||
"'zscore' = z-score normalization, 'minmax' = min-max to [0,1]. "
|
||||
"Uses Welford's online algorithm for running statistics.",
|
||||
)
|
||||
reward_clip: float = Field(
|
||||
default=5.0,
|
||||
description="Maximum absolute reward value after normalization. "
|
||||
"Only applies when reward_normalization is not 'none'. "
|
||||
"Set to 0 to disable clipping.",
|
||||
)
|
||||
reward_normalization_warmup: int = Field(
|
||||
default=10,
|
||||
description="Number of scored batches to observe before activating "
|
||||
"reward normalization. During warmup, raw scores are used.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@ -262,6 +279,17 @@ class BaseEnv(ABC):
|
|||
self.max_token_len = -1
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
|
||||
self.completion_lengths = []
|
||||
# Initialize reward normalizer (opt-in via config)
|
||||
if config.reward_normalization != "none":
|
||||
from atroposlib.envs.reward_normalization import RewardNormalizer
|
||||
|
||||
self.reward_normalizer = RewardNormalizer(
|
||||
mode=config.reward_normalization,
|
||||
clip=config.reward_clip,
|
||||
warmup=config.reward_normalization_warmup,
|
||||
)
|
||||
else:
|
||||
self.reward_normalizer = None
|
||||
self.max_num_workers = config.max_num_workers
|
||||
if self.max_num_workers == -1:
|
||||
self.max_num_workers = config.max_num_workers_per_node * len(
|
||||
|
|
@ -674,6 +702,9 @@ class BaseEnv(ABC):
|
|||
wandb_metrics["train/completion_lengths_p95"] = (
|
||||
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
|
||||
).mean()
|
||||
# Log reward normalization metrics if active
|
||||
if self.reward_normalizer is not None:
|
||||
wandb_metrics.update(self.reward_normalizer.metrics_dict())
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
||||
wandb_metrics = self.perf_stats(wandb_metrics)
|
||||
self.rollouts_for_wandb = []
|
||||
|
|
@ -892,6 +923,16 @@ class BaseEnv(ABC):
|
|||
logger.warning("Scores are the same in a group, skipping...")
|
||||
continue
|
||||
|
||||
# Apply reward normalization if enabled (opt-in via config)
|
||||
if self.reward_normalizer is not None:
|
||||
group["scores"] = self.reward_normalizer.normalize(group["scores"])
|
||||
# Re-check after normalization: if all scores collapsed, skip
|
||||
if len(set(group["scores"])) == 1:
|
||||
logger.debug(
|
||||
"Scores collapsed to same value after normalization, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
group.setdefault("ref_logprobs", None)
|
||||
group.setdefault("overrides", None)
|
||||
group.setdefault("group_overrides", None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue