metric calc diff

This commit is contained in:
Jai Suphavadeeprasit 2026-01-29 13:11:29 -05:00
parent 04652fd97c
commit e2b111fea0
2 changed files with 60 additions and 13 deletions

View file

@@ -288,10 +288,25 @@ def compute_logprob_alignment(
# Compute diff (key metric for alignment validation)
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
# Old metric: difference of means (can be misleading)
stats["logprobs/diff_of_means"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
# At step 0, this diff should be very close to 0 if weights are shared correctly
# A large diff indicates the training model is using different weights than vLLM
# Better metric: mean of per-token absolute differences (like real-time test)
# This requires matching token counts
min_len = min(len(inf_filtered), train_flat.numel())
if min_len > 0:
inf_subset = inf_filtered[:min_len]
train_subset = train_flat[:min_len].float().cpu().numpy()
per_token_diff = np.abs(inf_subset - train_subset)
stats["logprobs/mean_abs_diff"] = float(np.mean(per_token_diff))
stats["logprobs/diff"] = stats["logprobs/mean_abs_diff"] # Use this as primary metric
if debug:
print(f" [DEBUG] Per-token comparison ({min_len} tokens)")
print(f" [DEBUG] Mean abs diff: {stats['logprobs/mean_abs_diff']:.4f}")
print(f" [DEBUG] First 5 diffs: {per_token_diff[:5]}")
else:
stats["logprobs/diff"] = stats["logprobs/diff_of_means"]
return stats
@@ -391,7 +406,7 @@ def run_training_step(
# Compute logprob alignment stats
if inference_logprobs is not None and all_training_logprobs:
alignment_stats = compute_logprob_alignment(
inference_logprobs, all_training_logprobs, debug=False
inference_logprobs, all_training_logprobs, debug=True # Enable for debugging
)
_logprob_alignment_stats.update(alignment_stats)
result["logprob_alignment"] = alignment_stats
@@ -467,14 +482,31 @@ def log_metrics(
# 1. --enforce-eager is set on vLLM
# 2. IPC attachment succeeded with ~100% coverage
# 3. Initial step 1 diff is < 0.5 (before much training)
if abs(diff) < 0.3:
status = "OK"
elif abs(diff) < 0.5:
status = "OK (data may be stale)"
# Use mean_abs_diff if available (better metric), otherwise diff_of_means
mean_abs_diff = _logprob_alignment_stats.get("logprobs/mean_abs_diff")
if mean_abs_diff is not None:
# Per-token comparison (like real-time test)
if mean_abs_diff < 0.05:
status = "PERFECT"
elif mean_abs_diff < 0.15:
status = "OK"
elif mean_abs_diff < 0.3:
status = "OK (some drift)"
else:
status = "stale data"
print(f" LogProb Alignment: mean_abs_diff={mean_abs_diff:.4f} [{status}]")
else:
status = "stale data" # Expected after training progresses
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
f"diff={diff:.4f} [{status}]")
# Fallback to diff of means
if abs(diff) < 0.3:
status = "OK"
elif abs(diff) < 0.5:
status = "OK (data may be stale)"
else:
status = "stale data"
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
f"diff={diff:.4f} [{status}]")
if use_wandb:
log_dict = {