This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 15:57:12 -05:00
parent 13def6bdab
commit 3e7705c17d
2 changed files with 39 additions and 26 deletions

View file

@ -238,13 +238,20 @@ def compute_grpo_loss(
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
# KL penalty: encourage staying close to reference policy
# KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference)
# We use the approximation: KL ≈ (ratio - 1) - log(ratio)
# But simpler: just penalize squared log-ratio which is symmetric
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
# D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1
#
# In terms of log probabilities:
# log_ratio = log π_θ - log π_ref (what we computed above)
# ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ
# kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1
# = exp(-log_ratio) + log_ratio - 1
#
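# This estimator is guaranteed to be non-negative, unlike the naive
# log-ratio estimator (log π_θ - log π_ref), which can be negative per token.
# The squared log-ratio used previously is also non-negative, but it is a
# biased estimate of the KL, whereas this one is unbiased for samples from π_θ.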
if kl_coef > 0:
# Approximate KL using (log_ratio)^2 / 2 (Taylor expansion)
# Or just use log_ratio directly as a penalty
kl_per_token = log_ratio.pow(2) # Squared for symmetric penalty
# Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
# = exp(-log_ratio) + log_ratio - 1
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
else:
@ -257,9 +264,11 @@ def compute_grpo_loss(
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring
# Mean ratio and KL for monitoring (using Schulman's estimator)
mean_ratio = (ratio * mask).sum() / mask.sum()
mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
# Schulman KL: exp(-log_ratio) + log_ratio - 1
schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
mean_kl = (schulman_kl * mask).sum() / mask.sum()
# For backward compatibility: collect training logprobs
raw_logp_per_token = -F.cross_entropy(