mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
logprobs
This commit is contained in:
parent
8f1f8acbde
commit
77c592c909
1 changed files with 48 additions and 6 deletions
|
|
@ -125,11 +125,16 @@ def compute_grpo_loss(
|
|||
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
|
||||
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
|
||||
|
||||
# Track logprobs for alignment verification
|
||||
inference_logprobs_flat = None
|
||||
logprob_diff_mean = 0.0
|
||||
logprob_diff_abs_mean = 0.0
|
||||
logprob_diff_max = 0.0
|
||||
|
||||
# === GRPO/PPO Loss Computation ===
|
||||
if use_reference_logprobs and inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
|
||||
|
||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||
with torch.no_grad():
|
||||
|
|
@ -137,6 +142,17 @@ def compute_grpo_loss(
|
|||
ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
|
||||
train_at_generated = (logp_per_token * mask).sum() / mask.sum()
|
||||
|
||||
# Extract logprobs at generated positions for alignment tracking
|
||||
inference_logprobs_flat = ref_logprobs[mask.bool()].detach()
|
||||
training_at_mask = logp_per_token[mask.bool()].detach()
|
||||
|
||||
# Token-level difference: THE key metric for alignment verification
|
||||
# If weights are truly shared, this should be ~0 at step start
|
||||
token_diff = training_at_mask - inference_logprobs_flat
|
||||
logprob_diff_mean = token_diff.mean().item()
|
||||
logprob_diff_abs_mean = token_diff.abs().mean().item()
|
||||
logprob_diff_max = token_diff.abs().max().item()
|
||||
|
||||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||
if ref_at_generated > 0.5:
|
||||
|
|
@ -234,12 +250,17 @@ def compute_grpo_loss(
|
|||
"pos_count": pos.sum().item(),
|
||||
"neg_count": neg.sum().item(),
|
||||
"training_logprobs": training_logprobs_flat,
|
||||
"inference_logprobs": inference_logprobs_flat,
|
||||
"interpretable_loss": interpretable_loss,
|
||||
# GRPO-specific metrics
|
||||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
||||
# Token-level alignment metrics (key for verifying weight sharing)
|
||||
"logprob_diff_mean": logprob_diff_mean,
|
||||
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
||||
"logprob_diff_max": logprob_diff_max,
|
||||
}
|
||||
|
||||
return total_loss, metrics
|
||||
|
|
@ -286,8 +307,12 @@ def run_training_step(
|
|||
total_mean_ratio = 0.0
|
||||
total_mean_kl = 0.0
|
||||
total_clipped_fraction = 0.0
|
||||
total_logprob_diff_mean = 0.0
|
||||
total_logprob_diff_abs_mean = 0.0
|
||||
total_logprob_diff_max = 0.0
|
||||
grad_norm = 0.0
|
||||
all_training_logprobs: List[torch.Tensor] = []
|
||||
all_inference_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
||||
|
|
@ -335,9 +360,16 @@ def run_training_step(
|
|||
total_mean_kl += metrics.get("mean_kl", 0.0)
|
||||
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
|
||||
|
||||
# Collect training logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics:
|
||||
# Accumulate token-level alignment metrics
|
||||
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
|
||||
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
|
||||
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
|
||||
|
||||
# Collect logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
||||
all_training_logprobs.append(metrics["training_logprobs"])
|
||||
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
|
||||
all_inference_logprobs.append(metrics["inference_logprobs"])
|
||||
|
||||
# Gradient clipping and optimizer step
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
|
|
@ -368,15 +400,25 @@ def run_training_step(
|
|||
}
|
||||
|
||||
# Compute logprob alignment stats for monitoring
|
||||
# NOTE: Now that we use proper GRPO, this is less critical
|
||||
# but still useful for debugging weight sharing issues
|
||||
# This proves weight sharing is working: inference & training logprobs should converge
|
||||
if all_training_logprobs:
|
||||
# Store training logprob stats
|
||||
train_flat = torch.cat(all_training_logprobs)
|
||||
if train_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
if all_inference_logprobs:
|
||||
inf_flat = torch.cat(all_inference_logprobs)
|
||||
if inf_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item()
|
||||
|
||||
# Token-level alignment metrics - THE key metric for verifying weight sharing
|
||||
# diff_abs_mean close to 0 = weights are truly shared
|
||||
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue