mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
metric calc diff
This commit is contained in:
parent
04652fd97c
commit
e2b111fea0
2 changed files with 60 additions and 13 deletions
|
|
@ -288,10 +288,25 @@ def compute_logprob_alignment(
|
|||
|
||||
# Compute diff (key metric for alignment validation)
|
||||
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
|
||||
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
|
||||
# Old metric: difference of means (can be misleading)
|
||||
stats["logprobs/diff_of_means"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
|
||||
|
||||
# At step 0, this diff should be very close to 0 if weights are shared correctly
|
||||
# A large diff indicates the training model is using different weights than vLLM
|
||||
# Better metric: mean of per-token absolute differences (like real-time test)
|
||||
# This requires matching token counts
|
||||
min_len = min(len(inf_filtered), train_flat.numel())
|
||||
if min_len > 0:
|
||||
inf_subset = inf_filtered[:min_len]
|
||||
train_subset = train_flat[:min_len].float().cpu().numpy()
|
||||
per_token_diff = np.abs(inf_subset - train_subset)
|
||||
stats["logprobs/mean_abs_diff"] = float(np.mean(per_token_diff))
|
||||
stats["logprobs/diff"] = stats["logprobs/mean_abs_diff"] # Use this as primary metric
|
||||
|
||||
if debug:
|
||||
print(f" [DEBUG] Per-token comparison ({min_len} tokens)")
|
||||
print(f" [DEBUG] Mean abs diff: {stats['logprobs/mean_abs_diff']:.4f}")
|
||||
print(f" [DEBUG] First 5 diffs: {per_token_diff[:5]}")
|
||||
else:
|
||||
stats["logprobs/diff"] = stats["logprobs/diff_of_means"]
|
||||
|
||||
return stats
|
||||
|
||||
|
|
@ -391,7 +406,7 @@ def run_training_step(
|
|||
# Compute logprob alignment stats
|
||||
if inference_logprobs is not None and all_training_logprobs:
|
||||
alignment_stats = compute_logprob_alignment(
|
||||
inference_logprobs, all_training_logprobs, debug=False
|
||||
inference_logprobs, all_training_logprobs, debug=True # Enable for debugging
|
||||
)
|
||||
_logprob_alignment_stats.update(alignment_stats)
|
||||
result["logprob_alignment"] = alignment_stats
|
||||
|
|
@ -467,14 +482,31 @@ def log_metrics(
|
|||
# 1. --enforce-eager is set on vLLM
|
||||
# 2. IPC attachment succeeded with ~100% coverage
|
||||
# 3. Initial step 1 diff is < 0.5 (before much training)
|
||||
if abs(diff) < 0.3:
|
||||
status = "OK"
|
||||
elif abs(diff) < 0.5:
|
||||
status = "OK (data may be stale)"
|
||||
|
||||
# Use mean_abs_diff if available (better metric), otherwise diff_of_means
|
||||
mean_abs_diff = _logprob_alignment_stats.get("logprobs/mean_abs_diff")
|
||||
|
||||
if mean_abs_diff is not None:
|
||||
# Per-token comparison (like real-time test)
|
||||
if mean_abs_diff < 0.05:
|
||||
status = "PERFECT"
|
||||
elif mean_abs_diff < 0.15:
|
||||
status = "OK"
|
||||
elif mean_abs_diff < 0.3:
|
||||
status = "OK (some drift)"
|
||||
else:
|
||||
status = "stale data"
|
||||
print(f" LogProb Alignment: mean_abs_diff={mean_abs_diff:.4f} [{status}]")
|
||||
else:
|
||||
status = "stale data" # Expected after training progresses
|
||||
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
|
||||
f"diff={diff:.4f} [{status}]")
|
||||
# Fallback to diff of means
|
||||
if abs(diff) < 0.3:
|
||||
status = "OK"
|
||||
elif abs(diff) < 0.5:
|
||||
status = "OK (data may be stale)"
|
||||
else:
|
||||
status = "stale data"
|
||||
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
|
||||
f"diff={diff:.4f} [{status}]")
|
||||
|
||||
if use_wandb:
|
||||
log_dict = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue