atropos/atroposlib/envs/reward_normalization.py
RUFFY-369 0674e31a53 feat: add online reward normalization for multi-env RL training stability
Add RewardNormalizer to atroposlib/envs/ with:
- Welford's online algorithm for running mean/variance (no data storage)
- Z-score and min-max normalization modes
- Configurable reward clipping and warmup period
- Checkpoint save/load support
- Opt-in integration in BaseEnv via 3 new config fields
- WandB metrics for normalization statistics

21/21 tests passing.
2026-03-28 03:31:28 +05:30

267 lines
8.5 KiB
Python

"""
Online reward normalization for multi-environment RL training stability.

Implements Welford's online algorithm for running mean/variance computation,
enabling z-score and min-max normalization of reward signals without needing
to store all historical values.

This is critical for multi-environment training where different environments
produce rewards on different scales (e.g., GSM8K gives {-1, 1} while
tool-use environments give continuous [0, 1] scores).

Usage:
    normalizer = RewardNormalizer(mode="zscore", clip=5.0)

    # During training loop
    scores = [0.5, -0.3, 0.8, 1.0]
    normalized = normalizer.normalize(scores)

    # Checkpointing
    state = normalizer.state_dict()
    normalizer.load_state_dict(state)
"""
import logging
import math
from typing import Any, Dict, List, Optional
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class WelfordAccumulator:
    """
    Welford's online algorithm for computing running mean and variance.

    Numerically stable single-pass algorithm that avoids catastrophic
    cancellation. Maintains count, mean, and M2 (sum of squared deviations
    from the running mean) to compute variance on demand, and also tracks
    the running minimum and maximum.

    Non-finite values (NaN, +inf, -inf) are ignored by ``update``. Folding
    even one NaN into the running mean would make every later statistic
    NaN, with no way to recover short of resetting the accumulator.

    Reference: Welford, B. P. (1962). "Note on a method for calculating
    corrected sums of squares and products". Technometrics. 4(3): 419-420.
    """

    def __init__(self):
        self.count: int = 0
        self.mean: float = 0.0
        # Sum of squared deviations from the running mean (Welford's M2).
        self._m2: float = 0.0
        # Sentinels chosen so the first observed value replaces both.
        self._min: float = float("inf")
        self._max: float = float("-inf")

    def update(self, value: float) -> None:
        """
        Update running statistics with a new value.

        Non-finite values are skipped and leave every statistic unchanged.
        """
        if not math.isfinite(value):
            return
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        # delta2 uses the *updated* mean; delta * delta2 is the numerically
        # stable increment to M2.
        delta2 = value - self.mean
        self._m2 += delta * delta2
        self._min = min(self._min, value)
        self._max = max(self._max, value)

    def update_batch(self, values: List[float]) -> None:
        """Update running statistics with each value in ``values``, in order."""
        for v in values:
            self.update(v)

    @property
    def variance(self) -> float:
        """Population variance of all observed values (0.0 if fewer than 2)."""
        if self.count < 2:
            return 0.0
        return self._m2 / self.count

    @property
    def std(self) -> float:
        """Population standard deviation of all observed values."""
        return math.sqrt(self.variance)

    @property
    def min_val(self) -> float:
        """Minimum observed value, or 0.0 if nothing has been observed."""
        return self._min if self.count > 0 else 0.0

    @property
    def max_val(self) -> float:
        """Maximum observed value, or 0.0 if nothing has been observed."""
        return self._max if self.count > 0 else 0.0

    def state_dict(self) -> Dict[str, Any]:
        """Serialize state for checkpointing."""
        return {
            "count": self.count,
            "mean": self.mean,
            "m2": self._m2,
            "min": self._min,
            "max": self._max,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """
        Restore state from checkpoint (inverse of ``state_dict``).

        Values are coerced to builtin ``int``/``float`` so that other numeric
        types in the checkpoint are normalized on load.
        """
        self.count = int(state["count"])
        self.mean = float(state["mean"])
        self._m2 = float(state["m2"])
        self._min = float(state["min"])
        self._max = float(state["max"])
class RewardNormalizer:
    """
    Reward normalization for stable multi-environment RL training.

    Supports two normalization modes:
    - "zscore": Standardize to zero mean, unit variance using running stats
    - "minmax": Scale to [0, 1] range using observed min/max
    A third mode, "none", returns scores unchanged. It collects no
    statistics and applies no clipping.

    Both modes use Welford's online algorithm so no historical data storage
    is required. Optional reward clipping prevents extreme values from
    destabilizing training.

    Non-finite scores (NaN, +inf, -inf) are excluded from the running
    statistics, so a single bad reward cannot corrupt every later
    normalization. Clipping keeps NaN as NaN so it stays detectable
    downstream, and clamps +/-inf to +/-clip.

    Args:
        mode: Normalization mode. One of "zscore", "minmax", or "none".
        clip: Maximum absolute value after normalization. Set to 0 or None
            to disable clipping. Default: 5.0.
        warmup: Minimum number of samples before normalization activates.
            During warmup, raw scores are returned (optionally clipped).
            Default: 10.
        eps: Small constant for numerical stability in division. Default: 1e-8.

    Raises:
        ValueError: If ``mode`` is not one of ``VALID_MODES``.
    """

    VALID_MODES = {"zscore", "minmax", "none"}

    def __init__(
        self,
        mode: str = "zscore",
        clip: Optional[float] = 5.0,
        warmup: int = 10,
        eps: float = 1e-8,
    ):
        self._check_mode(mode)
        self.mode = mode
        self.clip = self._normalize_clip(clip)
        self.warmup = max(0, warmup)
        self.eps = eps
        self._accumulator = WelfordAccumulator()

    @classmethod
    def _check_mode(cls, mode: str) -> None:
        """Raise ValueError unless ``mode`` is a supported normalization mode."""
        if mode not in cls.VALID_MODES:
            raise ValueError(
                f"Invalid normalization mode '{mode}'. "
                f"Must be one of: {cls.VALID_MODES}"
            )

    @staticmethod
    def _normalize_clip(clip: Optional[float]) -> Optional[float]:
        """Map a missing or non-positive clip value to None (clipping disabled)."""
        return clip if clip and clip > 0 else None

    @property
    def count(self) -> int:
        """Number of (finite) samples observed."""
        return self._accumulator.count

    @property
    def mean(self) -> float:
        """Running mean of observed values."""
        return self._accumulator.mean

    @property
    def std(self) -> float:
        """Running standard deviation of observed values."""
        return self._accumulator.std

    @property
    def is_warmed_up(self) -> bool:
        """Whether enough samples have been observed for normalization."""
        return self._accumulator.count >= self.warmup

    def normalize(self, scores: List[float]) -> List[float]:
        """
        Normalize a batch of reward scores.

        Updates running statistics with the batch's finite scores, then
        applies normalization. During warmup, raw scores are returned (with
        optional clipping). In "none" mode, scores are returned unchanged.

        Args:
            scores: Raw reward scores to normalize. Any iterable of numbers
                is accepted (list, tuple, generator, 1-D array). ``None`` is
                treated as empty.

        Returns:
            Normalized (and optionally clipped) scores as a new list.
        """
        # Materialize once. The scores are traversed several times below, so
        # a generator would otherwise be exhausted by the statistics update.
        # This also avoids the ambiguous truth value of multi-element arrays.
        scores = [] if scores is None else list(scores)
        if not scores:
            return []
        if self.mode == "none":
            return scores

        # Only finite values enter the running statistics. A NaN/inf would
        # otherwise poison the mean/variance for the rest of training.
        finite = [s for s in scores if math.isfinite(s)]
        skipped = len(scores) - len(finite)
        if skipped:
            logger.warning(
                "Reward normalizer: excluded %d non-finite score(s) from "
                "running statistics",
                skipped,
            )
        self._accumulator.update_batch(finite)

        # During warmup, return raw scores (optionally clipped)
        if not self.is_warmed_up:
            logger.debug(
                "Reward normalizer warmup: %d/%d samples",
                self._accumulator.count,
                self.warmup,
            )
            return self._clip(scores)

        if self.mode == "zscore":
            normalized = self._zscore(scores)
        elif self.mode == "minmax":
            normalized = self._minmax(scores)
        else:
            # Unreachable: mode is validated in __init__/load_state_dict.
            normalized = scores
        return self._clip(normalized)

    def _zscore(self, scores: List[float]) -> List[float]:
        """Z-score normalize: (x - mean) / std."""
        mean = self._accumulator.mean
        std = self._accumulator.std
        if std < self.eps:
            # All values nearly identical -- return zeros
            return [0.0] * len(scores)
        return [(s - mean) / (std + self.eps) for s in scores]

    def _minmax(self, scores: List[float]) -> List[float]:
        """Min-max normalize to [0, 1] range."""
        min_val = self._accumulator.min_val
        max_val = self._accumulator.max_val
        range_val = max_val - min_val
        if range_val < self.eps:
            # Degenerate range: map everything to the midpoint.
            return [0.5] * len(scores)
        return [(s - min_val) / (range_val + self.eps) for s in scores]

    def _clip(self, scores: List[float]) -> List[float]:
        """
        Clip scores to the [-clip, clip] range.

        NaN is passed through unchanged. A plain min/max would silently
        turn it into +clip, because ``min(clip, nan)`` returns ``clip``.
        """
        if self.clip is None:
            return scores
        lo, hi = -self.clip, self.clip
        return [s if math.isnan(s) else max(lo, min(hi, s)) for s in scores]

    def metrics_dict(self) -> Dict[str, float]:
        """
        Return current normalization statistics for WandB logging.

        Returns:
            Dictionary with keys suitable for wandb.log().
        """
        return {
            "reward_norm/count": float(self._accumulator.count),
            "reward_norm/mean": self._accumulator.mean,
            "reward_norm/std": self._accumulator.std,
            "reward_norm/min": self._accumulator.min_val,
            "reward_norm/max": self._accumulator.max_val,
        }

    def state_dict(self) -> Dict[str, Any]:
        """Serialize full state for checkpointing."""
        return {
            "mode": self.mode,
            "clip": self.clip,
            "warmup": self.warmup,
            "eps": self.eps,
            "accumulator": self._accumulator.state_dict(),
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """
        Restore state from checkpoint.

        Applies the same validation and normalization as ``__init__``. All
        top-level fields are read and the mode is checked before anything is
        assigned, so a bad checkpoint does not leave this object half-updated.

        Raises:
            ValueError: If the checkpoint's mode is not one of ``VALID_MODES``.
            KeyError: If a required field is missing from ``state``.
        """
        mode = state["mode"]
        clip = state["clip"]
        warmup = state["warmup"]
        eps = state["eps"]
        accumulator_state = state["accumulator"]
        self._check_mode(mode)

        self.mode = mode
        self.clip = self._normalize_clip(clip)
        self.warmup = max(0, warmup)
        self.eps = eps
        self._accumulator.load_state_dict(accumulator_state)