This commit is contained in:
Jai Suphavadeeprasit 2026-02-03 14:34:02 -05:00
parent 8f1f8acbde
commit 77c592c909

View file

@ -125,11 +125,16 @@ def compute_grpo_loss(
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
# Track logprobs for alignment verification
inference_logprobs_flat = None
logprob_diff_mean = 0.0
logprob_diff_abs_mean = 0.0
logprob_diff_max = 0.0
# === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
# NOTE: inference_logprobs holds a 1.0 placeholder at masked (prompt) positions and the actual (negative) logprobs at generated positions
with torch.no_grad():
@ -137,6 +142,17 @@ def compute_grpo_loss(
ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
train_at_generated = (logp_per_token * mask).sum() / mask.sum()
# Extract logprobs at generated positions for alignment tracking
inference_logprobs_flat = ref_logprobs[mask.bool()].detach()
training_at_mask = logp_per_token[mask.bool()].detach()
# Token-level difference: THE key metric for alignment verification
# If weights are truly shared, this should be ~0 at step start
token_diff = training_at_mask - inference_logprobs_flat
logprob_diff_mean = token_diff.mean().item()
logprob_diff_abs_mean = token_diff.abs().mean().item()
logprob_diff_max = token_diff.abs().max().item()
# Check if ref logprobs are negative (as they should be for generated tokens)
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
if ref_at_generated > 0.5:
@ -234,12 +250,17 @@ def compute_grpo_loss(
"pos_count": pos.sum().item(),
"neg_count": neg.sum().item(),
"training_logprobs": training_logprobs_flat,
"inference_logprobs": inference_logprobs_flat,
"interpretable_loss": interpretable_loss,
# GRPO-specific metrics
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
# Token-level alignment metrics (key for verifying weight sharing)
"logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean,
"logprob_diff_max": logprob_diff_max,
}
return total_loss, metrics
@ -286,8 +307,12 @@ def run_training_step(
total_mean_ratio = 0.0
total_mean_kl = 0.0
total_clipped_fraction = 0.0
total_logprob_diff_mean = 0.0
total_logprob_diff_abs_mean = 0.0
total_logprob_diff_max = 0.0
grad_norm = 0.0
all_training_logprobs: List[torch.Tensor] = []
all_inference_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, 'kl_coef', 0.1)
@ -335,9 +360,16 @@ def run_training_step(
total_mean_kl += metrics.get("mean_kl", 0.0)
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
# Collect training logprobs for alignment monitoring
if "training_logprobs" in metrics:
# Accumulate token-level alignment metrics
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
# Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
all_training_logprobs.append(metrics["training_logprobs"])
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
all_inference_logprobs.append(metrics["inference_logprobs"])
# Gradient clipping and optimizer step
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
@ -368,15 +400,25 @@ def run_training_step(
}
# Compute logprob alignment stats for monitoring
# NOTE: Now that we use proper GRPO, this is less critical
# but still useful for debugging weight sharing issues
# These stats help verify that weight sharing is working: inference & training logprobs should converge
if all_training_logprobs:
# Store training logprob stats
train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
if all_inference_logprobs:
inf_flat = torch.cat(all_inference_logprobs)
if inf_flat.numel() > 0:
_logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item()
_logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item()
# Token-level alignment metrics - THE key metric for verifying weight sharing
# diff_abs_mean close to 0 = weights are truly shared
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
return result