This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 16:06:19 -05:00
parent 2b5debe0a2
commit 851f0b6e17
2 changed files with 29 additions and 0 deletions

View file

@ -214,6 +214,18 @@ def compute_grpo_loss(
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
# DEBUG: Check if inference logprobs look valid
with torch.no_grad():
ref_nonzero = (ref_logprobs != 0).float()
ref_nonzero_frac = (ref_nonzero * mask).sum() / mask.sum()
ref_mean = (ref_logprobs * mask).sum() / mask.sum()
train_mean = (logp_per_token * mask).sum() / mask.sum()
if ref_nonzero_frac < 0.5:
print(f" [WARNING] Only {ref_nonzero_frac*100:.1f}% of inference logprobs are non-zero!")
print(f" [WARNING] This suggests inference_logprobs field may be missing from data")
if abs(ref_mean - train_mean) > 1.0:
print(f" [DEBUG] Large logprob gap: ref_mean={ref_mean:.3f}, train_mean={train_mean:.3f}")
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
log_ratio = logp_per_token - ref_logprobs
ratio = torch.exp(log_ratio)