feat: add online reward normalization for multi-env RL training stability

Add RewardNormalizer to atroposlib/envs/ with:
- Welford's online algorithm for running mean/variance (no data storage)
- Z-score and min-max normalization modes
- Configurable reward clipping and warmup period
- Checkpoint save/load support
- Opt-in integration in BaseEnv via 3 new config fields
- WandB metrics for normalization statistics

21/21 tests passing.
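
For reference, a minimal sketch of the Welford update the normalizer is built on. RunningStats and its attribute names are illustrative, not the actual atroposlib API:

# Minimal sketch of Welford's online algorithm for running mean/variance.
# Single pass, numerically stable, O(1) memory -- no reward history is stored.
class RunningStats:
    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)  # uses the updated mean

    @property
    def variance(self) -> float:
        # Population variance; 0.0 until at least two samples are seen.
        return self.m2 / self.count if self.count > 1 else 0.0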
Author: RUFFY-369
Date: 2026-03-28 03:31:28 +05:30
Commit: 0674e31a53 (parent: c421582b6f)
3 changed files with 560 additions and 0 deletions

@@ -211,6 +211,23 @@ class BaseEnvConfig(BaseModel):
        "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
        "eval_helpers for the standard Hermes reasoning prompt.",
    )
    reward_normalization: str = Field(
        default="none",
        description="Reward normalization mode. 'none' = disabled (default), "
        "'zscore' = z-score normalization, 'minmax' = min-max to [0,1]. "
        "Uses Welford's online algorithm for running statistics.",
    )
    reward_clip: float = Field(
        default=5.0,
        description="Maximum absolute reward value after normalization. "
        "Only applies when reward_normalization is not 'none'. "
        "Set to 0 to disable clipping.",
    )
    reward_normalization_warmup: int = Field(
        default=10,
        description="Number of scored batches to observe before activating "
        "reward normalization. During warmup, raw scores are used.",
    )


class BaseEnv(ABC):
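
Enabling it is config-only. A hypothetical snippet showing the three new fields; values are examples, and other required BaseEnvConfig fields (tokenizer_name, worker settings, etc.) are omitted for brevity:

# Hypothetical opt-in usage of the three new config fields.
config = BaseEnvConfig(
    reward_normalization="zscore",   # 'none' (default) | 'zscore' | 'minmax'
    reward_clip=5.0,                 # clamp normalized rewards to [-5.0, 5.0]; 0 disables
    reward_normalization_warmup=10,  # raw scores for the first 10 scored batches
)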
@@ -262,6 +279,17 @@ class BaseEnv(ABC):
        self.max_token_len = -1
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        self.completion_lengths = []
        # Initialize reward normalizer (opt-in via config)
        if config.reward_normalization != "none":
            from atroposlib.envs.reward_normalization import RewardNormalizer

            self.reward_normalizer = RewardNormalizer(
                mode=config.reward_normalization,
                clip=config.reward_clip,
                warmup=config.reward_normalization_warmup,
            )
        else:
            self.reward_normalizer = None
        self.max_num_workers = config.max_num_workers
        if self.max_num_workers == -1:
            self.max_num_workers = config.max_num_workers_per_node * len(
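
A sketch of a constructor matching the call site above; everything beyond the three keyword arguments (mode, clip, warmup) is an assumption about the implementation:

# Sketch only: attribute names and the min/max tracking are assumptions,
# not the actual RewardNormalizer internals.
class RewardNormalizer:
    def __init__(self, mode: str, clip: float, warmup: int):
        self.mode = mode              # 'zscore' or 'minmax'
        self.clip = clip              # max |reward| after normalization; 0 disables
        self.warmup = warmup          # scored batches to observe before normalizing
        self.stats = RunningStats()   # Welford accumulator (see sketch above)
        self.min = float("inf")       # running min/max, needed for 'minmax' mode
        self.max = float("-inf")
        self.batches_seen = 0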
@@ -674,6 +702,9 @@ class BaseEnv(ABC):
        wandb_metrics["train/completion_lengths_p95"] = (
            np.array(self.completion_lengths) > (0.95 * self.max_token_len)
        ).mean()
        # Log reward normalization metrics if active
        if self.reward_normalizer is not None:
            wandb_metrics.update(self.reward_normalizer.metrics_dict())
        wandb_metrics = await self.create_rollout_table(wandb_metrics)
        wandb_metrics = self.perf_stats(wandb_metrics)
        self.rollouts_for_wandb = []
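
The exact keys metrics_dict() reports are not shown in these hunks; a hypothetical shape of the WandB payload, continuing the sketch above:

# Hypothetical metrics payload; the real keys may differ.
def metrics_dict(self) -> dict:
    return {
        "reward_norm/mean": self.stats.mean,
        "reward_norm/std": self.stats.variance ** 0.5,
        "reward_norm/count": self.stats.count,
        "reward_norm/warmed_up": float(self.batches_seen >= self.warmup),
    }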
@@ -892,6 +923,16 @@ class BaseEnv(ABC):
                logger.warning("Scores are the same in a group, skipping...")
                continue
            # Apply reward normalization if enabled (opt-in via config)
            if self.reward_normalizer is not None:
                group["scores"] = self.reward_normalizer.normalize(group["scores"])
                # Re-check after normalization: if all scores collapsed, skip
                if len(set(group["scores"])) == 1:
                    logger.debug(
                        "Scores collapsed to same value after normalization, skipping"
                    )
                    continue
            group.setdefault("ref_logprobs", None)
            group.setdefault("overrides", None)
            group.setdefault("group_overrides", None)