[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions


@@ -18,10 +18,10 @@ import wandb
from .config import TrainingConfig
# Global storage for logprob alignment stats
_logprob_alignment_stats: Dict[str, float] = {}
def setup_wandb(config: TrainingConfig) -> bool:
"""
Initialize Weights & Biases logging if enabled.
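For reference, a minimal sketch of the pattern such a helper typically wraps. `wandb.init` is the real API, but the project name, config fields, and helper name here are illustrative assumptions, not the module's code:

```python
import wandb

def setup_wandb_sketch(enabled: bool) -> bool:
    """Hypothetical stand-in for setup_wandb; returns True if logging is on."""
    if not enabled:
        return False
    # Project name and config values are illustrative, not from the source
    wandb.init(project="grpo-training", config={"kl_coef": 0.1, "clip_eps": 0.2})
    return True
```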
@@ -80,12 +80,12 @@ def compute_grpo_loss(
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
The loss encourages the model to:
- Increase probability for tokens with positive advantages
- Decrease probability for tokens with negative advantages
- Stay close to the reference policy (inference-time policy)
Args:
model: The model to compute loss for
tokens: Input token IDs [batch, seq_len]
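To make the objective described in the docstring concrete, a minimal standalone sketch of the clipped surrogate with a pessimistic bound, using toy tensors (names and values are illustrative):

```python
import torch

logp_new = torch.tensor([-0.9, -1.2, -2.0])  # current policy log-probs
logp_old = torch.tensor([-1.0, -1.0, -1.0])  # old (inference) policy log-probs
adv = torch.tensor([1.0, -0.5, 2.0])         # per-token advantages
clip_eps = 0.2

ratio = torch.exp(logp_new - logp_old)                   # pi / pi_old
clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
surr1, surr2 = ratio * adv, clipped * adv
# Pessimistic bound: min for A > 0, max for A < 0, negated to form a loss
loss = -torch.where(adv > 0, torch.min(surr1, surr2), torch.max(surr1, surr2)).mean()
```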
@@ -105,7 +105,7 @@ def compute_grpo_loss(
outputs = model(tokens)
logits = outputs.logits
# Apply the sampling temperature to the training logits; otherwise the likelihood ratio is off
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
scaled_logits = logits / t
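A minimal sketch of why this matters: if sampling ran at temperature T, training log-probs must come from `logits / T`, or the importance ratio compares mismatched distributions (toy shapes; names are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 32)            # [batch, seq_len, vocab]
tokens = torch.randint(0, 32, (1, 4, 1))  # sampled token ids
T = 0.7                                    # sampling temperature

# Log-probs under the same temperature the sampler used
logp_scaled = F.log_softmax(logits / T, dim=-1).gather(-1, tokens).squeeze(-1)
# Unscaled log-probs describe a different distribution than the sampler's
logp_unscaled = F.log_softmax(logits, dim=-1).gather(-1, tokens).squeeze(-1)
```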
@@ -130,48 +130,56 @@ def compute_grpo_loss(
logprob_diff_mean = 0.0
logprob_diff_abs_mean = 0.0
logprob_diff_max = 0.0
# === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(
logp_per_token.device, logp_per_token.dtype
)
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
with torch.no_grad():
# Only look at generated positions (where mask == 1)
ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
train_at_generated = (logp_per_token * mask).sum() / mask.sum()
# Extract logprobs at generated positions for alignment tracking
inference_logprobs_flat = ref_logprobs[mask.bool()].detach()
training_at_mask = logp_per_token[mask.bool()].detach()
# Token-level difference: THE key metric for alignment verification
# If weights are truly shared, this should be ~0 at step start
token_diff = training_at_mask - inference_logprobs_flat
logprob_diff_mean = token_diff.mean().item()
logprob_diff_abs_mean = token_diff.abs().mean().item()
logprob_diff_max = token_diff.abs().max().item()
# Check if ref logprobs are negative (as they should be for generated tokens)
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
if ref_at_generated > 0.5:
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
print(
f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
)
print(
" [WARNING] This suggests inference_logprobs alignment is wrong"
)
elif abs(ref_at_generated - train_at_generated) > 2.0:
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
print(
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
)
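A standalone illustration of the masking convention noted above: prompt positions hold a 1.0 placeholder, so averages must be taken only where the mask is set (toy tensors):

```python
import torch

ref_logprobs = torch.tensor([[1.0, 1.0, -0.8, -1.5]])  # 1.0 = prompt placeholder
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])            # 1 = generated token

ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()  # -> -1.15
assert ref_at_generated < 0  # a positive mean would flag misalignment
```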
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
log_ratio = logp_per_token - ref_logprobs
ratio = torch.exp(log_ratio)
# PPO-style clipping
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
# Surrogate objectives
surr1 = ratio * adv_expanded
surr2 = clipped_ratio * adv_expanded
# Pessimistic bound: min for positive advantages, max for negative
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
# -max(ratio * A, clipped_ratio * A) when A < 0
@@ -180,10 +188,10 @@ def compute_grpo_loss(
torch.min(surr1, surr2),
torch.max(surr1, surr2),
)
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
# KL penalty: encourage staying close to reference policy
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
@@ -192,23 +200,27 @@ def compute_grpo_loss(
# = exp(-log_ratio) + log_ratio - 1
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (
policy_loss + kl_coef * kl_penalty
) / gradient_accumulation_steps
else:
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
total_loss = policy_loss / gradient_accumulation_steps
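A quick numeric check of the estimator used above (often called k3): exp(-x) + x - 1 is non-negative for every real x, with equality only at x = 0 (toy values, not part of the module):

```python
import torch

log_ratio = torch.tensor([-0.5, 0.0, 0.3, 2.0])
k3 = torch.exp(-log_ratio) + log_ratio - 1.0
assert (k3 >= 0).all()  # unbiased, non-negative per-token KL estimate
```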
# Compute metrics for logging
with torch.no_grad():
# Fraction of tokens where ratio was clipped
clipped_fraction = (
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring (using Schulman's estimator)
mean_ratio = (ratio * mask).sum() / mask.sum()
# Schulman KL: exp(-log_ratio) + log_ratio - 1
schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
mean_kl = (schulman_kl * mask).sum() / mask.sum()
# For backward compatibility: collect training logprobs
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
@@ -239,10 +251,10 @@ def compute_grpo_loss(
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
# Interpretable metric: advantage-weighted average logprob
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
metrics = {
"pos_logp": pos_logp,
"neg_logp": neg_logp,
@@ -256,7 +268,11 @@ def compute_grpo_loss(
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
"clipped_fraction": (
clipped_fraction.item()
if torch.is_tensor(clipped_fraction)
else clipped_fraction
),
# Token-level alignment metrics (key for verifying weight sharing)
"logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean,
@@ -284,7 +300,7 @@ def run_training_step(
2. Backward pass with gradient accumulation
3. Gradient clipping
4. Optimizer step
Args:
model: The model to train
optimizer: The optimizer
@@ -315,23 +331,25 @@ def run_training_step(
all_inference_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, "kl_coef", 0.1)
clip_eps = getattr(config, "clip_eps", 0.2)
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
# Accumulate gradients over micro-batches
num_batches = len(token_batches) if token_batches else 1
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
zip(token_batches, label_batches, advantage_batches, temperature_batches)
):
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
# Get corresponding inference logprobs batch if available
inf_logprobs = None
if inference_logprob_batches is not None and batch_idx < len(
inference_logprob_batches
):
inf_logprobs = inference_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss(
@@ -353,29 +371,34 @@ def run_training_step(
total_neg_logp += metrics["neg_logp"]
total_pos += metrics["pos_count"]
total_neg += metrics["neg_count"]
# Accumulate GRPO-specific metrics
total_kl_penalty += metrics.get("kl_penalty", 0.0)
total_mean_ratio += metrics.get("mean_ratio", 1.0)
total_mean_kl += metrics.get("mean_kl", 0.0)
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
# Accumulate token-level alignment metrics
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
total_logprob_diff_max = max(
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
)
# Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
all_training_logprobs.append(metrics["training_logprobs"])
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
if (
"inference_logprobs" in metrics
and metrics["inference_logprobs"] is not None
):
all_inference_logprobs.append(metrics["inference_logprobs"])
# Gradient clipping and optimizer step
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
# Help prevent memory fragmentation
torch.cuda.empty_cache()
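For context, the accumulation pattern these steps implement, as a minimal standalone sketch (the model, optimizer, and data here are illustrative stand-ins, not the module's objects):

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
micro_batches = [torch.randn(4, 8) for _ in range(4)]
accum_steps = len(micro_batches)

optimizer.zero_grad()
for x in micro_batches:
    loss = model(x).pow(2).mean() / accum_steps  # pre-scaled, as above
    loss.backward()                               # grads accumulate across calls
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```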
@@ -387,7 +410,7 @@ def run_training_step(
result = {
"loss": total_loss,
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
"pos_logp": total_pos_logp,
"neg_logp": total_neg_logp,
"pos_count": total_pos,
@@ -398,27 +421,33 @@ def run_training_step(
"mean_kl": total_mean_kl / num_batches,
"clipped_fraction": total_clipped_fraction / num_batches,
}
# Compute logprob alignment stats for monitoring
# This proves weight sharing is working: inference & training logprobs should converge
if all_training_logprobs:
train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
_logprob_alignment_stats["logprobs/training_mean"] = (
train_flat.mean().item()
)
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
if all_inference_logprobs:
inf_flat = torch.cat(all_inference_logprobs)
if inf_flat.numel() > 0:
_logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item()
_logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item()
# Token-level alignment metrics - THE key metric for verifying weight sharing
# diff_abs_mean close to 0 = weights are truly shared
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
_logprob_alignment_stats["alignment/diff_mean"] = (
total_logprob_diff_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
total_logprob_diff_abs_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
return result
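A hedged sketch of how a caller might consume these stats after a step; the threshold is an illustrative assumption, not a value from the source:

```python
# After run_training_step(...), check alignment before trusting the ratio
diff = _logprob_alignment_stats.get("alignment/diff_abs_mean", float("inf"))
if diff > 0.1:  # illustrative threshold
    print(f"[WARN] training/inference logprobs diverge: {diff:.4f}")
```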
@@ -464,7 +493,7 @@ def log_metrics(
mean_ratio = metrics.get("mean_ratio", 1.0)
mean_kl = metrics.get("mean_kl", 0)
clipped_frac = metrics.get("clipped_fraction", 0)
if kl_penalty > 0 or mean_kl > 0:
print(
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
@@ -495,15 +524,20 @@ def log_metrics(
"grpo/clipped_fraction": clipped_frac,
}
# Add timing metrics if present
for key in ["step_time", "sync_time", "data_fetch_time",
"gpu_memory_gb", "gpu_memory_reserved_gb"]:
for key in [
"step_time",
"sync_time",
"data_fetch_time",
"gpu_memory_gb",
"gpu_memory_reserved_gb",
]:
if key in metrics:
log_dict[f"train/{key}"] = metrics[key]
# Add logprob alignment stats
if _logprob_alignment_stats:
log_dict.update(_logprob_alignment_stats)
if extra_metrics:
log_dict.update(extra_metrics)
wandb.log(log_dict, step=step)
@@ -549,7 +583,9 @@ def finalize_training(
total_step_time = sum(step_times)
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
total_sync_time = sum(sync_times)
avg_data_fetch = (
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
)
total_data_fetch = sum(data_fetch_times)
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
@@ -557,13 +593,17 @@ def finalize_training(
print(f"\n{'='*70}")
print(f"BENCHMARK SUMMARY ({mode})")
print(f"{'='*70}")
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
)
print(f" Total steps: {total_steps}")
print(" ")
print(" TIMING BREAKDOWN:")
print(f" Avg step time: {avg_step_time:.2f}s")
print(f" Total step time: {total_step_time:.2f}s")
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
print(
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
)
print(f" Total sync time: {total_sync_time:.2f}s")
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
print(f" Total data fetch time: {total_data_fetch:.2f}s")
@@ -584,4 +624,3 @@ def finalize_training(
wandb.finish()
elif use_wandb:
wandb.finish()