diff --git a/example_trainer/training.py b/example_trainer/training.py
index 4f1b8ccc..1a49364f 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -286,27 +286,12 @@ def compute_logprob_alignment(
     stats["logprobs/training_mean"] = train_flat.mean().item()
     stats["logprobs/training_std"] = train_flat.std().item()

-    # Compute diff (key metric for alignment validation)
+    # Compute diff (for tracking, not validation)
+    # NOTE: Per-token comparison is NOT reliable here because inference and training
+    # logprobs come from different batch orderings and can't be aligned token-by-token.
+    # The real-time test at startup is the proper alignment validation.
     if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
-        # Old metric: difference of means (can be misleading)
-        stats["logprobs/diff_of_means"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
-
-        # Better metric: mean of per-token absolute differences (like real-time test)
-        # This requires matching token counts
-        min_len = min(len(inf_filtered), train_flat.numel())
-        if min_len > 0:
-            inf_subset = inf_filtered[:min_len]
-            train_subset = train_flat[:min_len].float().cpu().numpy()
-            per_token_diff = np.abs(inf_subset - train_subset)
-            stats["logprobs/mean_abs_diff"] = float(np.mean(per_token_diff))
-            stats["logprobs/diff"] = stats["logprobs/mean_abs_diff"]  # Use this as primary metric
-
-            if debug:
-                print(f" [DEBUG] Per-token comparison ({min_len} tokens)")
-                print(f" [DEBUG] Mean abs diff: {stats['logprobs/mean_abs_diff']:.4f}")
-                print(f" [DEBUG] First 5 diffs: {per_token_diff[:5]}")
-        else:
-            stats["logprobs/diff"] = stats["logprobs/diff_of_means"]
+        stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]

     return stats

@@ -404,9 +389,12 @@ def run_training_step(
     }

     # Compute logprob alignment stats
+    # NOTE: This comparison is approximate - inference and training logprobs
+    # come from different batching, so token-by-token alignment isn't possible.
+    # The real-time test at startup is the reliable alignment check.
     if inference_logprobs is not None and all_training_logprobs:
         alignment_stats = compute_logprob_alignment(
-            inference_logprobs, all_training_logprobs, debug=True  # Enable for debugging
+            inference_logprobs, all_training_logprobs, debug=False
         )
         _logprob_alignment_stats.update(alignment_stats)
         result["logprob_alignment"] = alignment_stats
@@ -477,36 +465,10 @@ def log_metrics(
     #   - training_logprobs: from trainer's current forward pass
     # After training starts, weights change, making comparison invalid.
     #
-    # The diff WILL increase over training - this is EXPECTED, not a bug!
-    # Trust weight sharing if:
-    #   1. --enforce-eager is set on vLLM
-    #   2. IPC attachment succeeded with ~100% coverage
-    #   3. Initial step 1 diff is < 0.5 (before much training)
-
-    # Use mean_abs_diff if available (better metric), otherwise diff_of_means
-    mean_abs_diff = _logprob_alignment_stats.get("logprobs/mean_abs_diff")
-
-    if mean_abs_diff is not None:
-        # Per-token comparison (like real-time test)
-        if mean_abs_diff < 0.05:
-            status = "PERFECT"
-        elif mean_abs_diff < 0.15:
-            status = "OK"
-        elif mean_abs_diff < 0.3:
-            status = "OK (some drift)"
-        else:
-            status = "stale data"
-        print(f" LogProb Alignment: mean_abs_diff={mean_abs_diff:.4f} [{status}]")
-    else:
-        # Fallback to diff of means
-        if abs(diff) < 0.3:
-            status = "OK"
-        elif abs(diff) < 0.5:
-            status = "OK (data may be stale)"
-        else:
-            status = "stale data"
-        print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
-              f"diff={diff:.4f} [{status}]")
+    # NOTE: This diff is just for monitoring, not validation!
+    # The real-time test at startup is the reliable alignment check.
+    # This diff will naturally drift as training progresses (expected).
+    print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")

     if use_wandb:
         log_dict = {
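
For reference, the startup real-time test that the new comments defer to is not part of this diff. Below is a minimal sketch of what such a check can look like, assuming an HF-style causal LM for the trainer (output exposes `.logits`); the function name, `prompt_len`, and `tol` are illustrative, not existing APIs. Unlike the in-training diff above, it scores the exact token sequence the inference engine sampled, before any optimizer step, so a per-token comparison is meaningful.

```python
import torch

def realtime_alignment_test(train_model, token_ids, inference_logprobs,
                            prompt_len, tol=0.05):
    """Per-token logprob comparison on the SAME tokens, in the SAME order.

    Run once at startup, before any optimizer step, so inference and
    training weights are still identical. All names here are illustrative.
    """
    ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)  # [1, seq]
    with torch.no_grad():
        # Assumes an HF-style model whose output carries .logits
        logits = train_model(ids).logits.float()
    # Token at position t is predicted by the logits at position t - 1
    logprobs = torch.log_softmax(logits[0, :-1], dim=-1)  # [seq-1, vocab]
    train_lp = logprobs.gather(1, ids[0, 1:].unsqueeze(1)).squeeze(1)

    # Compare only the completion tokens the inference engine scored
    # (assumes prompt_len >= 1: the first completion token sits at index
    # prompt_len, and its logprob lives at train_lp[prompt_len - 1])
    train_lp = train_lp[prompt_len - 1:]
    inf_lp = torch.tensor(inference_logprobs, dtype=torch.float32)
    n = min(train_lp.numel(), inf_lp.numel())
    assert n > 0, "no overlapping completion tokens to compare"
    mean_abs_diff = (train_lp[:n] - inf_lp[:n]).abs().mean().item()
    return mean_abs_diff, mean_abs_diff < tol
```

Because both sides score identical tokens in identical order, a threshold like `tol=0.05` is meaningful here, whereas the in-training `logprobs/diff` retained by this change can only be watched for gross staleness.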