mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -18,10 +18,10 @@ import wandb
|
|||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
# Global storage for logprob alignment stats
|
||||
_logprob_alignment_stats: Dict[str, float] = {}
|
||||
|
||||
|
||||
def setup_wandb(config: TrainingConfig) -> bool:
|
||||
"""
|
||||
Initialize Weights & Biases logging if enabled.
|
||||
|
|
@ -80,12 +80,12 @@ def compute_grpo_loss(
|
|||
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
|
||||
- PPO-style clipping to prevent large updates
|
||||
- KL penalty to prevent reward hacking/policy collapse
|
||||
|
||||
|
||||
The loss encourages the model to:
|
||||
- Increase probability for tokens with positive advantages
|
||||
- Decrease probability for tokens with negative advantages
|
||||
- Stay close to the reference policy (inference-time policy)
|
||||
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for
|
||||
tokens: Input token IDs [batch, seq_len]
|
||||
|
|
@ -105,7 +105,7 @@ def compute_grpo_loss(
|
|||
outputs = model(tokens)
|
||||
logits = outputs.logits
|
||||
|
||||
# Temperature scaling for training otherwise likely ratio is off
|
||||
# Temperature scaling for training otherwise likely ratio is off
|
||||
t = temperatures.to(logits.device, logits.dtype)
|
||||
t = torch.where(t <= 0, torch.ones_like(t), t)
|
||||
scaled_logits = logits / t
|
||||
|
|
@ -130,48 +130,56 @@ def compute_grpo_loss(
|
|||
logprob_diff_mean = 0.0
|
||||
logprob_diff_abs_mean = 0.0
|
||||
logprob_diff_max = 0.0
|
||||
|
||||
|
||||
# === GRPO/PPO Loss Computation ===
|
||||
if use_reference_logprobs and inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
ref_logprobs = inference_logprobs.to(
|
||||
logp_per_token.device, logp_per_token.dtype
|
||||
)
|
||||
|
||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||
with torch.no_grad():
|
||||
# Only look at generated positions (where mask == 1)
|
||||
ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
|
||||
train_at_generated = (logp_per_token * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
# Extract logprobs at generated positions for alignment tracking
|
||||
inference_logprobs_flat = ref_logprobs[mask.bool()].detach()
|
||||
training_at_mask = logp_per_token[mask.bool()].detach()
|
||||
|
||||
|
||||
# Token-level difference: THE key metric for alignment verification
|
||||
# If weights are truly shared, this should be ~0 at step start
|
||||
token_diff = training_at_mask - inference_logprobs_flat
|
||||
logprob_diff_mean = token_diff.mean().item()
|
||||
logprob_diff_abs_mean = token_diff.abs().mean().item()
|
||||
logprob_diff_max = token_diff.abs().max().item()
|
||||
|
||||
|
||||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||
if ref_at_generated > 0.5:
|
||||
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
|
||||
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
|
||||
print(
|
||||
f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
|
||||
)
|
||||
print(
|
||||
" [WARNING] This suggests inference_logprobs alignment is wrong"
|
||||
)
|
||||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
|
||||
print(
|
||||
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
|
||||
)
|
||||
|
||||
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
|
||||
# PPO-style clipping
|
||||
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
|
||||
|
||||
|
||||
# Surrogate objectives
|
||||
surr1 = ratio * adv_expanded
|
||||
surr2 = clipped_ratio * adv_expanded
|
||||
|
||||
|
||||
# Pessimistic bound: min for positive advantages, max for negative
|
||||
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
|
||||
# -max(ratio * A, clipped_ratio * A) when A < 0
|
||||
|
|
@ -180,10 +188,10 @@ def compute_grpo_loss(
|
|||
torch.min(surr1, surr2),
|
||||
torch.max(surr1, surr2),
|
||||
)
|
||||
|
||||
|
||||
# Average over tokens, then over batch
|
||||
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
|
||||
|
||||
# KL penalty: encourage staying close to reference policy
|
||||
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
|
||||
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
|
||||
|
|
@ -192,23 +200,27 @@ def compute_grpo_loss(
|
|||
# = exp(-log_ratio) + log_ratio - 1
|
||||
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
|
||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
|
||||
total_loss = (
|
||||
policy_loss + kl_coef * kl_penalty
|
||||
) / gradient_accumulation_steps
|
||||
else:
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
|
||||
|
||||
# Compute metrics for logging
|
||||
with torch.no_grad():
|
||||
# Fraction of tokens where ratio was clipped
|
||||
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
|
||||
clipped_fraction = (
|
||||
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
|
||||
).float()
|
||||
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
# Mean ratio and KL for monitoring (using Schulman's estimator)
|
||||
mean_ratio = (ratio * mask).sum() / mask.sum()
|
||||
# Schulman KL: exp(-log_ratio) + log_ratio - 1
|
||||
schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
|
||||
mean_kl = (schulman_kl * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
# For backward compatibility: collect training logprobs
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
|
|
@ -239,10 +251,10 @@ def compute_grpo_loss(
|
|||
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
|
||||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
|
||||
|
||||
# Interpretable metric: advantage-weighted average logprob
|
||||
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
||||
|
||||
|
||||
metrics = {
|
||||
"pos_logp": pos_logp,
|
||||
"neg_logp": neg_logp,
|
||||
|
|
@ -256,7 +268,11 @@ def compute_grpo_loss(
|
|||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
||||
"clipped_fraction": (
|
||||
clipped_fraction.item()
|
||||
if torch.is_tensor(clipped_fraction)
|
||||
else clipped_fraction
|
||||
),
|
||||
# Token-level alignment metrics (key for verifying weight sharing)
|
||||
"logprob_diff_mean": logprob_diff_mean,
|
||||
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
||||
|
|
@ -284,7 +300,7 @@ def run_training_step(
|
|||
2. Backward pass with gradient accumulation
|
||||
3. Gradient clipping
|
||||
4. Optimizer step
|
||||
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
optimizer: The optimizer
|
||||
|
|
@ -315,23 +331,25 @@ def run_training_step(
|
|||
all_inference_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
||||
clip_eps = getattr(config, 'clip_eps', 0.2)
|
||||
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
|
||||
kl_coef = getattr(config, "kl_coef", 0.1)
|
||||
clip_eps = getattr(config, "clip_eps", 0.2)
|
||||
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
|
||||
|
||||
# Accumulate gradients over micro-batches
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
|
||||
token_batches, label_batches, advantage_batches, temperature_batches
|
||||
)):
|
||||
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
|
||||
zip(token_batches, label_batches, advantage_batches, temperature_batches)
|
||||
):
|
||||
tokens = tokens.to(config.device)
|
||||
labels = labels.to(config.device)
|
||||
advantages = advantages.to(config.device)
|
||||
|
||||
|
||||
# Get corresponding inference logprobs batch if available
|
||||
inf_logprobs = None
|
||||
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
|
||||
if inference_logprob_batches is not None and batch_idx < len(
|
||||
inference_logprob_batches
|
||||
):
|
||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
|
|
@ -353,29 +371,34 @@ def run_training_step(
|
|||
total_neg_logp += metrics["neg_logp"]
|
||||
total_pos += metrics["pos_count"]
|
||||
total_neg += metrics["neg_count"]
|
||||
|
||||
|
||||
# Accumulate GRPO-specific metrics
|
||||
total_kl_penalty += metrics.get("kl_penalty", 0.0)
|
||||
total_mean_ratio += metrics.get("mean_ratio", 1.0)
|
||||
total_mean_kl += metrics.get("mean_kl", 0.0)
|
||||
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
|
||||
|
||||
|
||||
# Accumulate token-level alignment metrics
|
||||
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
|
||||
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
|
||||
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
|
||||
|
||||
total_logprob_diff_max = max(
|
||||
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
|
||||
)
|
||||
|
||||
# Collect logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
||||
all_training_logprobs.append(metrics["training_logprobs"])
|
||||
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
|
||||
if (
|
||||
"inference_logprobs" in metrics
|
||||
and metrics["inference_logprobs"] is not None
|
||||
):
|
||||
all_inference_logprobs.append(metrics["inference_logprobs"])
|
||||
|
||||
# Gradient clipping and optimizer step
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
# Help prevent memory fragmentation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
@ -387,7 +410,7 @@ def run_training_step(
|
|||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
|
||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
|
||||
"pos_logp": total_pos_logp,
|
||||
"neg_logp": total_neg_logp,
|
||||
"pos_count": total_pos,
|
||||
|
|
@ -398,27 +421,33 @@ def run_training_step(
|
|||
"mean_kl": total_mean_kl / num_batches,
|
||||
"clipped_fraction": total_clipped_fraction / num_batches,
|
||||
}
|
||||
|
||||
|
||||
# Compute logprob alignment stats for monitoring
|
||||
# This proves weight sharing is working: inference & training logprobs should converge
|
||||
if all_training_logprobs:
|
||||
train_flat = torch.cat(all_training_logprobs)
|
||||
if train_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = (
|
||||
train_flat.mean().item()
|
||||
)
|
||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
|
||||
if all_inference_logprobs:
|
||||
inf_flat = torch.cat(all_inference_logprobs)
|
||||
if inf_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item()
|
||||
|
||||
|
||||
# Token-level alignment metrics - THE key metric for verifying weight sharing
|
||||
# diff_abs_mean close to 0 = weights are truly shared
|
||||
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_mean"] = (
|
||||
total_logprob_diff_mean / num_batches
|
||||
)
|
||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
|
||||
total_logprob_diff_abs_mean / num_batches
|
||||
)
|
||||
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -464,7 +493,7 @@ def log_metrics(
|
|||
mean_ratio = metrics.get("mean_ratio", 1.0)
|
||||
mean_kl = metrics.get("mean_kl", 0)
|
||||
clipped_frac = metrics.get("clipped_fraction", 0)
|
||||
|
||||
|
||||
if kl_penalty > 0 or mean_kl > 0:
|
||||
print(
|
||||
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
|
||||
|
|
@ -495,15 +524,20 @@ def log_metrics(
|
|||
"grpo/clipped_fraction": clipped_frac,
|
||||
}
|
||||
# Add timing metrics if present
|
||||
for key in ["step_time", "sync_time", "data_fetch_time",
|
||||
"gpu_memory_gb", "gpu_memory_reserved_gb"]:
|
||||
for key in [
|
||||
"step_time",
|
||||
"sync_time",
|
||||
"data_fetch_time",
|
||||
"gpu_memory_gb",
|
||||
"gpu_memory_reserved_gb",
|
||||
]:
|
||||
if key in metrics:
|
||||
log_dict[f"train/{key}"] = metrics[key]
|
||||
|
||||
# Add logprob alignment stats
|
||||
|
||||
# Add logprob alignment stats
|
||||
if _logprob_alignment_stats:
|
||||
log_dict.update(_logprob_alignment_stats)
|
||||
|
||||
|
||||
if extra_metrics:
|
||||
log_dict.update(extra_metrics)
|
||||
wandb.log(log_dict, step=step)
|
||||
|
|
@ -549,7 +583,9 @@ def finalize_training(
|
|||
total_step_time = sum(step_times)
|
||||
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
|
||||
total_sync_time = sum(sync_times)
|
||||
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
||||
avg_data_fetch = (
|
||||
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
||||
)
|
||||
total_data_fetch = sum(data_fetch_times)
|
||||
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
|
||||
|
||||
|
|
@ -557,13 +593,17 @@ def finalize_training(
|
|||
print(f"\n{'='*70}")
|
||||
print(f"BENCHMARK SUMMARY ({mode})")
|
||||
print(f"{'='*70}")
|
||||
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
|
||||
print(
|
||||
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
|
||||
)
|
||||
print(f" Total steps: {total_steps}")
|
||||
print(" ")
|
||||
print(" TIMING BREAKDOWN:")
|
||||
print(f" Avg step time: {avg_step_time:.2f}s")
|
||||
print(f" Total step time: {total_step_time:.2f}s")
|
||||
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
|
||||
print(
|
||||
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
|
||||
)
|
||||
print(f" Total sync time: {total_sync_time:.2f}s")
|
||||
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
|
||||
print(f" Total data fetch time: {total_data_fetch:.2f}s")
|
||||
|
|
@ -584,4 +624,3 @@ def finalize_training(
|
|||
wandb.finish()
|
||||
elif use_wandb:
|
||||
wandb.finish()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue