metric calc diff

This commit is contained in:
Jai Suphavadeeprasit 2026-01-29 13:25:06 -05:00
parent 4112329308
commit 05df4a4953

View file

@ -286,27 +286,12 @@ def compute_logprob_alignment(
stats["logprobs/training_mean"] = train_flat.mean().item()
stats["logprobs/training_std"] = train_flat.std().item()
# Compute diff (key metric for alignment validation)
# Compute diff (for tracking, not validation)
# NOTE: Per-token comparison is NOT reliable here because inference and training
# logprobs come from different batch orderings and can't be aligned token-by-token.
# The real-time test at startup is the proper alignment validation.
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
# Old metric: difference of means (can be misleading)
stats["logprobs/diff_of_means"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
# Better metric: mean of per-token absolute differences (like real-time test)
# This requires matching token counts
min_len = min(len(inf_filtered), train_flat.numel())
if min_len > 0:
inf_subset = inf_filtered[:min_len]
train_subset = train_flat[:min_len].float().cpu().numpy()
per_token_diff = np.abs(inf_subset - train_subset)
stats["logprobs/mean_abs_diff"] = float(np.mean(per_token_diff))
stats["logprobs/diff"] = stats["logprobs/mean_abs_diff"] # Use this as primary metric
if debug:
print(f" [DEBUG] Per-token comparison ({min_len} tokens)")
print(f" [DEBUG] Mean abs diff: {stats['logprobs/mean_abs_diff']:.4f}")
print(f" [DEBUG] First 5 diffs: {per_token_diff[:5]}")
else:
stats["logprobs/diff"] = stats["logprobs/diff_of_means"]
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
return stats
@ -404,9 +389,12 @@ def run_training_step(
}
# Compute logprob alignment stats
# NOTE: This comparison is approximate - inference and training logprobs
# come from different batching, so token-by-token alignment isn't possible.
# The real-time test at startup is the reliable alignment check.
if inference_logprobs is not None and all_training_logprobs:
alignment_stats = compute_logprob_alignment(
inference_logprobs, all_training_logprobs, debug=True # Enable for debugging
inference_logprobs, all_training_logprobs, debug=False
)
_logprob_alignment_stats.update(alignment_stats)
result["logprob_alignment"] = alignment_stats
@ -477,36 +465,10 @@ def log_metrics(
# - training_logprobs: from trainer's current forward pass
# After training starts, weights change, making comparison invalid.
#
# The diff WILL increase over training - this is EXPECTED, not a bug!
# Trust weight sharing if:
# 1. --enforce-eager is set on vLLM
# 2. IPC attachment succeeded with ~100% coverage
# 3. Initial step 1 diff is < 0.5 (before much training)
# Use mean_abs_diff if available (better metric), otherwise diff_of_means
mean_abs_diff = _logprob_alignment_stats.get("logprobs/mean_abs_diff")
if mean_abs_diff is not None:
# Per-token comparison (like real-time test)
if mean_abs_diff < 0.05:
status = "PERFECT"
elif mean_abs_diff < 0.15:
status = "OK"
elif mean_abs_diff < 0.3:
status = "OK (some drift)"
else:
status = "stale data"
print(f" LogProb Alignment: mean_abs_diff={mean_abs_diff:.4f} [{status}]")
else:
# Fallback to diff of means
if abs(diff) < 0.3:
status = "OK"
elif abs(diff) < 0.5:
status = "OK (data may be stale)"
else:
status = "stale data"
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
f"diff={diff:.4f} [{status}]")
# NOTE: This diff is just for monitoring, not validation!
# The real-time test at startup is the reliable alignment check.
# This diff will naturally drift as training progresses (expected).
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
if use_wandb:
log_dict = {