Jai Suphavadeeprasit 2026-02-03 12:10:39 -05:00
parent 04f2850980
commit 0b61dd047a
2 changed files with 453 additions and 540 deletions


@@ -12,7 +12,6 @@ import string
import time
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import wandb
@@ -131,7 +130,7 @@ def compute_grpo_loss(
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
# DEBUG: Check if inference logprobs look valid
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
with torch.no_grad():
# Only look at generated positions (where mask == 1)
@@ -146,7 +145,7 @@ def compute_grpo_loss(
elif abs(ref_at_generated - train_at_generated) > 2.0:
print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
log_ratio = logp_per_token - ref_logprobs
ratio = torch.exp(log_ratio)
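The surr1/surr2 terms used by the pessimistic bound below are elided by this hunk; the following is a minimal sketch of how a clipped surrogate is typically formed from the ratio. The clip_eps name and all values are assumptions for illustration, not taken from this module.
# Sketch only: clip_eps and the sample values are assumed, not part of this file.
log_ratio = torch.tensor([0.10, -0.30, 0.05])   # log pi - log pi_old, per generated token
adv = torch.tensor([1.0, -0.5, 2.0])            # per-token advantages
clip_eps = 0.2
ratio = torch.exp(log_ratio)
surr1 = ratio * adv                                               # unclipped surrogate
surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv  # clipped surrogate
loss_per_token = -torch.where(adv >= 0, torch.min(surr1, surr2), torch.max(surr1, surr2))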
@@ -159,7 +158,7 @@ def compute_grpo_loss(
# Pessimistic bound: min for positive advantages, max for negative
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
# -max(ratio * A, clipped_ratio * A) when A < 0
# -max(ratio * A, clipped_ratio * A) when A < 0
policy_loss_per_token = -torch.where(
adv_expanded >= 0,
torch.min(surr1, surr2),
@@ -171,14 +170,6 @@ def compute_grpo_loss(
# KL penalty: encourage staying close to reference policy
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
# D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1
#
# In terms of log probabilities:
# log_ratio = log π_θ - log π_ref (what we computed above)
# ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ
# kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1
# = exp(-log_ratio) + log_ratio - 1
#
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
if kl_coef > 0:
# Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
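A minimal numeric sketch of the estimator described in the comment block above; the values are invented and only the module-level torch import is assumed.
# Sketch only: kl = exp(-log_ratio) + log_ratio - 1, per the derivation above.
_lr = torch.tensor([[0.2, -0.1, 0.0]])   # log pi_theta - log pi_ref, per token
_kl = torch.exp(-_lr) + _lr - 1          # (pi_ref/pi_theta) - log(pi_ref/pi_theta) - 1
assert (_kl >= 0).all()                  # non-negative for any log-ratio, unlike the squared log-ratio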
@@ -211,26 +202,18 @@ def compute_grpo_loss(
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
else:
# Fallback: REINFORCE-style (no reference policy)
# This is what the original code did - NOT recommended!
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
# Simple policy gradient: -log(π) * A
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
total_loss = policy_loss / gradient_accumulation_steps
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
with torch.no_grad():
clipped_fraction = torch.tensor(0.0)
mean_ratio = torch.tensor(1.0)
mean_kl = torch.tensor(0.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# Fail loudly
raise ValueError(
"GRPO requires inference_logprobs for importance sampling!\n"
"\n"
"This error means the environment isn't providing logprobs. To fix:\n"
" 1. Use --openai.server_type vllm (not 'openai')\n"
" 2. Ensure vLLM is returning logprobs in /generate response\n"
" 3. Check that gsm8k_server is configured correctly\n"
"\n"
"Without inference logprobs, training will cause reward hacking.\n"
"If you REALLY want vanilla REINFORCE (not recommended), set use_reference_logprobs=False"
)
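For context, a sketch of the inference_logprobs convention the DEBUG note above describes; the values are invented and only the module-level numpy import is assumed.
# Sketch only: one array per sequence, 1.0 placeholders at prompt positions,
# real (negative) logprobs at generated positions.
example_inference_logprobs = [
    np.array([1.0, 1.0, 1.0, -0.42, -1.37, -0.05]),  # 3 prompt tokens, 3 generated tokens
]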
# === Compute Additional Metrics ===
with torch.no_grad():
@@ -262,67 +245,6 @@ def compute_grpo_loss(
return total_loss, metrics
def compute_logprob_alignment(
inference_logprobs: List[np.ndarray],
training_logprobs: List[torch.Tensor],
debug: bool = False,
) -> Dict[str, float]:
"""
Compute alignment stats between inference and training logprobs.
At initialization (step 0), these should match closely if the model
weights are correctly shared between training and inference.
Args:
inference_logprobs: Logprobs from vLLM inference (numpy arrays)
training_logprobs: Logprobs computed during training forward pass (PyTorch tensors, bfloat16 supported)
debug: If True, print detailed debugging info
Returns:
Dict of alignment statistics
"""
if not inference_logprobs or not training_logprobs:
return {}
# Process inference logprobs (numpy)
inf_flat = np.concatenate(inference_logprobs)
# Filter out placeholder values (1.0 or 0.0 used for prompt tokens)
inf_mask = (inf_flat != 1.0) & (inf_flat != 0.0)
inf_filtered = inf_flat[inf_mask]
# Process training logprobs (PyTorch - supports bfloat16 natively)
train_flat = torch.cat(training_logprobs)
if debug:
print(f" [DEBUG] Inference: {len(inf_flat)} total, {len(inf_filtered)} after filter")
print(f" [DEBUG] Training: {train_flat.numel()} logprobs")
if len(inf_filtered) > 0:
print(f" [DEBUG] Inf sample (first 5): {inf_filtered[:5]}")
if train_flat.numel() > 0:
print(f" [DEBUG] Train sample (first 5): {train_flat[:5].tolist()}")
# Compute stats using PyTorch for training (keeps bfloat16 precision)
stats = {}
if len(inf_filtered) > 0:
stats["logprobs/inference_mean"] = float(np.mean(inf_filtered))
stats["logprobs/inference_std"] = float(np.std(inf_filtered))
if train_flat.numel() > 0:
# PyTorch operations - fully support bfloat16
stats["logprobs/training_mean"] = train_flat.mean().item()
stats["logprobs/training_std"] = train_flat.std().item()
# Compute diff (for tracking, not validation)
# NOTE: Per-token comparison is NOT reliable here because inference and training
# logprobs come from different batch orderings and can't be aligned token-by-token.
# The real-time test at startup is the proper alignment validation.
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
return stats
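An illustrative call of compute_logprob_alignment as displayed in this hunk; the inputs are invented, and the numpy/torch imports at the top of the module are assumed.
# Sketch only: 1.0 entries are prompt placeholders that the function filters out.
inf_lps = [np.array([1.0, 1.0, -0.35, -1.02]), np.array([1.0, -0.80, -0.12])]
train_lps = [torch.tensor([-0.34, -1.05], dtype=torch.bfloat16),
             torch.tensor([-0.79, -0.11], dtype=torch.bfloat16)]
stats = compute_logprob_alignment(inf_lps, train_lps, debug=True)
# stats now holds inference/training means, stds, and their diff (for tracking only)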
def run_training_step(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,