diff --git a/example_trainer/training.py b/example_trainer/training.py index fdf87c23..6dfe7f6a 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -125,11 +125,16 @@ def compute_grpo_loss( # Expand advantages to match token shape [batch, 1] -> [batch, seq_len] adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device) + # Track logprobs for alignment verification + inference_logprobs_flat = None + logprob_diff_mean = 0.0 + logprob_diff_abs_mean = 0.0 + logprob_diff_max = 0.0 + # === GRPO/PPO Loss Computation === if use_reference_logprobs and inference_logprobs is not None: # Move inference logprobs to correct device/dtype ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype) - # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated with torch.no_grad(): @@ -137,6 +142,17 @@ def compute_grpo_loss( ref_at_generated = (ref_logprobs * mask).sum() / mask.sum() train_at_generated = (logp_per_token * mask).sum() / mask.sum() + # Extract logprobs at generated positions for alignment tracking + inference_logprobs_flat = ref_logprobs[mask.bool()].detach() + training_at_mask = logp_per_token[mask.bool()].detach() + + # Token-level difference: THE key metric for alignment verification + # If weights are truly shared, this should be ~0 at step start + token_diff = training_at_mask - inference_logprobs_flat + logprob_diff_mean = token_diff.mean().item() + logprob_diff_abs_mean = token_diff.abs().mean().item() + logprob_diff_max = token_diff.abs().max().item() + # Check if ref logprobs are negative (as they should be for generated tokens) # If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used if ref_at_generated > 0.5: @@ -234,12 +250,17 @@ def compute_grpo_loss( "pos_count": pos.sum().item(), "neg_count": neg.sum().item(), "training_logprobs": training_logprobs_flat, + "inference_logprobs": inference_logprobs_flat, "interpretable_loss": interpretable_loss, # GRPO-specific metrics "kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty, "mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio, "mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl, "clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction, + # Token-level alignment metrics (key for verifying weight sharing) + "logprob_diff_mean": logprob_diff_mean, + "logprob_diff_abs_mean": logprob_diff_abs_mean, + "logprob_diff_max": logprob_diff_max, } return total_loss, metrics @@ -286,8 +307,12 @@ def run_training_step( total_mean_ratio = 0.0 total_mean_kl = 0.0 total_clipped_fraction = 0.0 + total_logprob_diff_mean = 0.0 + total_logprob_diff_abs_mean = 0.0 + total_logprob_diff_max = 0.0 grad_norm = 0.0 all_training_logprobs: List[torch.Tensor] = [] + all_inference_logprobs: List[torch.Tensor] = [] # Get GRPO hyperparameters from config kl_coef = getattr(config, 'kl_coef', 0.1) @@ -335,9 +360,16 @@ def run_training_step( total_mean_kl += metrics.get("mean_kl", 0.0) total_clipped_fraction += metrics.get("clipped_fraction", 0.0) - # Collect training logprobs for alignment monitoring - if "training_logprobs" in metrics: + # Accumulate token-level alignment metrics + total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0) + total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0) + total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)) + + # Collect logprobs for alignment monitoring + if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: all_training_logprobs.append(metrics["training_logprobs"]) + if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None: + all_inference_logprobs.append(metrics["inference_logprobs"]) # Gradient clipping and optimizer step grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) @@ -368,15 +400,25 @@ def run_training_step( } # Compute logprob alignment stats for monitoring - # NOTE: Now that we use proper GRPO, this is less critical - # but still useful for debugging weight sharing issues + # This proves weight sharing is working: inference & training logprobs should converge if all_training_logprobs: - # Store training logprob stats train_flat = torch.cat(all_training_logprobs) if train_flat.numel() > 0: _logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item() _logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item() + if all_inference_logprobs: + inf_flat = torch.cat(all_inference_logprobs) + if inf_flat.numel() > 0: + _logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item() + _logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item() + + # Token-level alignment metrics - THE key metric for verifying weight sharing + # diff_abs_mean close to 0 = weights are truly shared + _logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches + _logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches + _logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max + return result