Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-28 17:29:30 +00:00

Commit 3e7705c17d (parent 13def6bdab): KL
2 changed files with 39 additions and 26 deletions
@@ -238,13 +238,20 @@ def compute_grpo_loss(
     policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
 
     # KL penalty: encourage staying close to reference policy
-    # KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference)
-    # We use the approximation: KL ≈ (ratio - 1) - log(ratio)
-    # But simpler: just penalize squared log-ratio which is symmetric
+    # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
+    #   D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1
+    #
+    # In terms of log probabilities:
+    #   log_ratio = log π_θ - log π_ref   (what we computed above)
+    #   ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ
+    #   kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1
+    #      = exp(-log_ratio) + log_ratio - 1
+    #
+    # This estimator is non-negative and, unlike the squared log-ratio, unbiased.
     if kl_coef > 0:
-        # Approximate KL using (log_ratio)^2 / 2 (Taylor expansion)
-        # Or just use log_ratio directly as a penalty
-        kl_per_token = log_ratio.pow(2)  # Squared for symmetric penalty
+        # Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
+        #                                 = exp(-log_ratio) + log_ratio - 1
+        kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
         kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
         total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
     else:
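For reference, a minimal standalone sketch of what the added lines compute. The tensor names (policy_logprobs, ref_logprobs, mask) are illustrative stand-ins for the per-token log-probabilities and padding mask the surrounding function already has, not the repository's actual signature; the point is to show the masked estimate and why it cannot go negative.

import torch

def schulman_kl_sketch(policy_logprobs, ref_logprobs, mask):
    # Per-token log π_θ - log π_ref
    log_ratio = policy_logprobs - ref_logprobs
    # k3-style estimator: (π_ref / π_θ) - log(π_ref / π_θ) - 1
    #                   = exp(-log_ratio) + log_ratio - 1
    kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
    # Masked mean over tokens, then mean over the batch
    mask_sum = mask.sum(dim=-1).clamp(min=1)
    return ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()

# Non-negativity: exp(x) >= 1 + x for all x, so with x = -log_ratio,
# exp(-log_ratio) >= 1 - log_ratio, hence kl_per_token >= 0 token-wise.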
@@ -257,9 +264,11 @@ def compute_grpo_loss(
     clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
     clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
 
-    # Mean ratio and KL for monitoring
+    # Mean ratio and KL for monitoring (using Schulman's estimator)
     mean_ratio = (ratio * mask).sum() / mask.sum()
-    mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
+    # Schulman KL: exp(-log_ratio) + log_ratio - 1
+    schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
+    mean_kl = (schulman_kl * mask).sum() / mask.sum()
 
     # For backward compatibility: collect training logprobs
     raw_logp_per_token = -F.cross_entropy(
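A small self-contained sketch of the monitoring side of the change, using toy random tensors in place of the real ratio, log_ratio, and mask (the names come from the diff; the shapes and clip_eps below are illustrative assumptions, not the repository's defaults). It shows that the logged mean_kl now uses the same estimator as the penalty term rather than the squared log-ratio.

import torch

torch.manual_seed(0)
clip_eps = 0.2  # assumed clipping epsilon, for illustration only

log_ratio = 0.1 * torch.randn(4, 16)   # stand-in for log π_θ - log π_ref
ratio = torch.exp(log_ratio)           # π_θ / π_ref
mask = torch.ones_like(log_ratio)      # pretend every token is valid

# Fraction of tokens outside the PPO-style clip range
clipped = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
clipped_fraction = (clipped * mask).sum() / mask.sum()

# New monitoring KL: identical to the penalty's Schulman estimator
schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
mean_kl = (schulman_kl * mask).sum() / mask.sum()

# Old metric, shown only for comparison: a squared log-ratio, not the same quantity
old_mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()

print(clipped_fraction.item(), mean_kl.item(), old_mean_kl.item())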