mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
metric calc diff
This commit is contained in:
parent
4112329308
commit
05df4a4953
1 changed files with 13 additions and 51 deletions
|
|
@@ -286,27 +286,12 @@ def compute_logprob_alignment(
|
|||
stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
# Compute diff (key metric for alignment validation)
|
||||
# Compute diff (for tracking, not validation)
|
||||
# NOTE: Per-token comparison is NOT reliable here because inference and training
|
||||
# logprobs come from different batch orderings and can't be aligned token-by-token.
|
||||
# The real-time test at startup is the proper alignment validation.
|
||||
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
|
||||
# Old metric: difference of means (can be misleading)
|
||||
stats["logprobs/diff_of_means"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
|
||||
|
||||
# Better metric: mean of per-token absolute differences (like real-time test)
|
||||
# This requires matching token counts
|
||||
min_len = min(len(inf_filtered), train_flat.numel())
|
||||
if min_len > 0:
|
||||
inf_subset = inf_filtered[:min_len]
|
||||
train_subset = train_flat[:min_len].float().cpu().numpy()
|
||||
per_token_diff = np.abs(inf_subset - train_subset)
|
||||
stats["logprobs/mean_abs_diff"] = float(np.mean(per_token_diff))
|
||||
stats["logprobs/diff"] = stats["logprobs/mean_abs_diff"] # Use this as primary metric
|
||||
|
||||
if debug:
|
||||
print(f" [DEBUG] Per-token comparison ({min_len} tokens)")
|
||||
print(f" [DEBUG] Mean abs diff: {stats['logprobs/mean_abs_diff']:.4f}")
|
||||
print(f" [DEBUG] First 5 diffs: {per_token_diff[:5]}")
|
||||
else:
|
||||
stats["logprobs/diff"] = stats["logprobs/diff_of_means"]
|
||||
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
|
||||
|
||||
return stats
|
||||
|
||||
|
|
@@ -404,9 +389,12 @@ def run_training_step(
|
|||
}
|
||||
|
||||
# Compute logprob alignment stats
|
||||
# NOTE: This comparison is approximate - inference and training logprobs
|
||||
# come from different batching, so token-by-token alignment isn't possible.
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
if inference_logprobs is not None and all_training_logprobs:
|
||||
alignment_stats = compute_logprob_alignment(
|
||||
inference_logprobs, all_training_logprobs, debug=True # Enable for debugging
|
||||
inference_logprobs, all_training_logprobs, debug=False
|
||||
)
|
||||
_logprob_alignment_stats.update(alignment_stats)
|
||||
result["logprob_alignment"] = alignment_stats
|
||||
|
|
@@ -477,36 +465,10 @@ def log_metrics(
|
|||
# - training_logprobs: from trainer's current forward pass
|
||||
# After training starts, weights change, making comparison invalid.
|
||||
#
|
||||
# The diff WILL increase over training - this is EXPECTED, not a bug!
|
||||
# Trust weight sharing if:
|
||||
# 1. --enforce-eager is set on vLLM
|
||||
# 2. IPC attachment succeeded with ~100% coverage
|
||||
# 3. Initial step 1 diff is < 0.5 (before much training)
|
||||
|
||||
# Use mean_abs_diff if available (better metric), otherwise diff_of_means
|
||||
mean_abs_diff = _logprob_alignment_stats.get("logprobs/mean_abs_diff")
|
||||
|
||||
if mean_abs_diff is not None:
|
||||
# Per-token comparison (like real-time test)
|
||||
if mean_abs_diff < 0.05:
|
||||
status = "PERFECT"
|
||||
elif mean_abs_diff < 0.15:
|
||||
status = "OK"
|
||||
elif mean_abs_diff < 0.3:
|
||||
status = "OK (some drift)"
|
||||
else:
|
||||
status = "stale data"
|
||||
print(f" LogProb Alignment: mean_abs_diff={mean_abs_diff:.4f} [{status}]")
|
||||
else:
|
||||
# Fallback to diff of means
|
||||
if abs(diff) < 0.3:
|
||||
status = "OK"
|
||||
elif abs(diff) < 0.5:
|
||||
status = "OK (data may be stale)"
|
||||
else:
|
||||
status = "stale data"
|
||||
print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
|
||||
f"diff={diff:.4f} [{status}]")
|
||||
# NOTE: This diff is just for monitoring, not validation!
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
# This diff will naturally drift as training progresses (expected).
|
||||
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
|
||||
|
||||
if use_wandb:
|
||||
log_dict = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue