manual testing

Jai Suphavadeeprasit 2026-02-02 15:40:24 -05:00
parent da046d3d3b
commit c1bb4f33f0
5 changed files with 329 additions and 766 deletions


@@ -153,13 +153,22 @@ def compute_grpo_loss(
temperatures: torch.Tensor,
gradient_accumulation_steps: int,
inference_logprobs: Optional[torch.Tensor] = None,
kl_coef: float = 0.1,
clip_eps: float = 0.2,
use_reference_logprobs: bool = True,
) -> Tuple[torch.Tensor, dict]:
"""
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
The GRPO loss encourages the model to:
This implements proper GRPO/PPO with:
- Importance sampling ratio: π(a|s) / π_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
The loss encourages the model to:
- Increase probability for tokens with positive advantages
- Decrease probability for tokens with negative advantages
- Stay close to the reference policy (inference-time policy)
Args:
model: The model to compute loss for
@@ -168,7 +177,10 @@ def compute_grpo_loss(
advantages: Advantage values [batch, 1]
temperatures: Temperature values [batch, 1, 1]
gradient_accumulation_steps: Number of accumulation steps (for scaling)
inference_logprobs: Optional logprobs from inference for alignment check
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
use_reference_logprobs: If True, use inference_logprobs as reference policy
Returns:
Tuple of (loss tensor, metrics dict)
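For reference, a call using the new keyword arguments could look like the sketch below; the variable names are illustrative and device placement is omitted (run_training_step, further down, handles both):

# Illustrative call only; tensor shapes follow the docstring above.
loss, metrics = compute_grpo_loss(
    model,
    tokens,                                   # [batch, seq_len]
    labels,                                   # [batch, seq_len], -100 on non-generated positions
    advantages,                               # [batch, 1]
    temperatures,                             # [batch, 1, 1]
    gradient_accumulation_steps=4,
    inference_logprobs=inference_logprobs,    # pi_old logprobs, [batch, seq_len]
    kl_coef=0.1,
    clip_eps=0.2,
    use_reference_logprobs=True,
)
loss.backward()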
@@ -177,14 +189,14 @@ def compute_grpo_loss(
outputs = model(tokens)
logits = outputs.logits
# Temperature scaling
# Temperature scaling for training
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
logits = logits / t
scaled_logits = logits / t
# Log probabilities per token
# Log probabilities per token (current policy π)
logp_per_token = -F.cross_entropy(
logits.view(-1, logits.size(-1)),
scaled_logits.view(-1, scaled_logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
@@ -192,39 +204,103 @@ def compute_grpo_loss(
# Masking based on labels != -100
mask = (labels != -100).float()
mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
# Compute metrics (no grad needed)
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
# === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
log_ratio = logp_per_token - ref_logprobs
ratio = torch.exp(log_ratio)
# PPO-style clipping
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
# Surrogate objectives
surr1 = ratio * adv_expanded
surr2 = clipped_ratio * adv_expanded
# Pessimistic bound: standard PPO takes the elementwise min of the two
# surrogates, which caps the gain for positive advantages and keeps the
# full penalty when the ratio grows on negative-advantage tokens.
policy_loss_per_token = -torch.min(surr1, surr2)
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
# KL penalty: keep the policy close to the reference (inference-time) policy.
# The per-token log_ratio = log(π/π_ref) is only a single-sample estimate of the KL;
# common low-variance estimators are (ratio - 1) - log(ratio) and (log_ratio)^2 / 2.
# Here we penalize the squared log-ratio, a symmetric proxy for the divergence.
if kl_coef > 0:
kl_per_token = log_ratio.pow(2)  # symmetric penalty on deviation from the reference
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
else:
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
total_loss = policy_loss / gradient_accumulation_steps
# Compute metrics for logging
with torch.no_grad():
# Fraction of tokens where ratio was clipped
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring
mean_ratio = (ratio * mask).sum() / mask.sum()
mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
# For backward compatibility: collect training logprobs
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
else:
# Fallback: REINFORCE-style (no reference policy)
# This is what the original code did - NOT recommended!
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
# Simple policy gradient: -log(π) * A
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
total_loss = policy_loss / gradient_accumulation_steps
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
with torch.no_grad():
clipped_fraction = torch.tensor(0.0)
mean_ratio = torch.tensor(1.0)
mean_kl = torch.tensor(0.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# === Compute Additional Metrics ===
with torch.no_grad():
pos = (advantages > 0).float()
neg = (advantages <= 0).float()
mask_float = mask.to(logp_per_token.dtype)
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
# For alignment check: compute logprobs WITHOUT temperature scaling
# This allows fair comparison with inference logprobs (which are at temp=1.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
# Collect raw training logprobs for masked positions (generated tokens only)
# Keep as PyTorch tensor (supports bfloat16 natively)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# GRPO loss: weighted log probabilities by advantages
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
grpo_loss = (
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
* advantages.to(logp_per_token.device)
).mean() / gradient_accumulation_steps
# Compute a more interpretable loss metric (advantage-weighted logprobs)
with torch.no_grad():
# Interpretable metric: advantage-weighted average logprob
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
metrics = {
@@ -233,11 +309,16 @@ def compute_grpo_loss(
"avg_logp": avg_logp,
"pos_count": pos.sum().item(),
"neg_count": neg.sum().item(),
"training_logprobs": training_logprobs_flat, # For alignment check
"interpretable_loss": interpretable_loss, # More meaningful metric
"training_logprobs": training_logprobs_flat,
"interpretable_loss": interpretable_loss,
# GRPO-specific metrics
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
}
return grpo_loss, metrics
return total_loss, metrics
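As a quick sanity check on the loss above, a minimal standalone sketch of the clipped surrogate plus the squared-log-ratio penalty on toy tensors (all values made up, not taken from this commit) might look like:

import torch

# Toy per-token log-probs for the current policy and the inference-time policy (pi_old).
logp = torch.tensor([[-1.2, -0.7, -2.0]])        # current policy, [batch=1, seq=3]
ref_logp = torch.tensor([[-1.0, -0.9, -2.0]])    # pi_old from inference
advantages = torch.tensor([[0.5]])               # one advantage per sequence
mask = torch.ones_like(logp)                     # pretend every token was generated
clip_eps, kl_coef = 0.2, 0.1

log_ratio = logp - ref_logp                      # log(pi / pi_old)
ratio = log_ratio.exp()
adv = advantages.expand_as(logp)

# Clipped surrogate: a single elementwise min is pessimistic for both advantage signs.
surr1 = ratio * adv
surr2 = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * adv
policy_loss = -(torch.min(surr1, surr2) * mask).sum(-1) / mask.sum(-1)

# Symmetric squared-log-ratio penalty, as in the function above.
kl_penalty = (log_ratio.pow(2) * mask).sum(-1) / mask.sum(-1)

loss = (policy_loss + kl_coef * kl_penalty).mean()
print(f"toy loss: {loss.item():.4f}")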
def compute_logprob_alignment(
@@ -309,17 +390,16 @@ def run_training_step(
advantage_batches: List[torch.Tensor],
temperature_batches: List[torch.Tensor],
config: TrainingConfig,
inference_logprobs: Optional[List[np.ndarray]] = None,
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
) -> dict:
"""
Run a single training step with gradient accumulation.
Performs:
1. Forward pass through all micro-batches
1. Forward pass through all micro-batches with proper GRPO loss
2. Backward pass with gradient accumulation
3. Gradient clipping
4. Optimizer step
5. (Optional) Logprob alignment check
Args:
model: The model to train
@@ -328,8 +408,8 @@ def run_training_step(
label_batches: List of label tensors
advantage_batches: List of advantage tensors
temperature_batches: List of temperature tensors
config: Training configuration
inference_logprobs: Optional logprobs from inference for alignment check
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
Returns:
Dict of training metrics for this step
@@ -341,16 +421,32 @@ def run_training_step(
total_neg_logp = 0.0
total_pos = 0.0
total_neg = 0.0
total_kl_penalty = 0.0
total_mean_ratio = 0.0
total_mean_kl = 0.0
total_clipped_fraction = 0.0
grad_norm = 0.0
all_training_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, 'kl_coef', 0.1)
clip_eps = getattr(config, 'clip_eps', 0.2)
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
# Accumulate gradients over micro-batches
for tokens, labels, advantages, temperatures in zip(
num_batches = len(token_batches) if token_batches else 1
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
token_batches, label_batches, advantage_batches, temperature_batches
):
)):
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
# Get corresponding inference logprobs batch if available
inf_logprobs = None
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
inf_logprobs = inference_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss(
model,
@@ -359,6 +455,10 @@ def run_training_step(
advantages,
temperatures,
config.gradient_accumulation_steps,
inference_logprobs=inf_logprobs,
kl_coef=kl_coef,
clip_eps=clip_eps,
use_reference_logprobs=use_reference_logprobs,
)
loss.backward()
@@ -368,7 +468,13 @@ def run_training_step(
total_pos += metrics["pos_count"]
total_neg += metrics["neg_count"]
# Collect training logprobs for alignment check
# Accumulate GRPO-specific metrics
total_kl_penalty += metrics.get("kl_penalty", 0.0)
total_mean_ratio += metrics.get("mean_ratio", 1.0)
total_mean_kl += metrics.get("mean_kl", 0.0)
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
# Collect training logprobs for alignment monitoring
if "training_logprobs" in metrics:
all_training_logprobs.append(metrics["training_logprobs"])
@@ -380,8 +486,7 @@ def run_training_step(
# Help prevent memory fragmentation
torch.cuda.empty_cache()
# Normalize metrics by count
num_batches = len(token_batches) if token_batches else 1
# Normalize metrics by batch count
if total_pos > 0:
total_pos_logp /= num_batches
if total_neg > 0:
@@ -394,18 +499,22 @@ def run_training_step(
"neg_logp": total_neg_logp,
"pos_count": total_pos,
"neg_count": total_neg,
# GRPO-specific metrics (averaged over batches)
"kl_penalty": total_kl_penalty / num_batches,
"mean_ratio": total_mean_ratio / num_batches,
"mean_kl": total_mean_kl / num_batches,
"clipped_fraction": total_clipped_fraction / num_batches,
}
# Compute logprob alignment stats
# NOTE: This comparison is approximate - inference and training logprobs
# come from different batching, so token-by-token alignment isn't possible.
# The real-time test at startup is the reliable alignment check.
if inference_logprobs is not None and all_training_logprobs:
alignment_stats = compute_logprob_alignment(
inference_logprobs, all_training_logprobs, debug=False
)
_logprob_alignment_stats.update(alignment_stats)
result["logprob_alignment"] = alignment_stats
# Compute logprob alignment stats for monitoring
# NOTE: Now that we use proper GRPO, this is less critical
# but still useful for debugging weight sharing issues
if all_training_logprobs:
# Store training logprob stats
train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
return result
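The new inference_logprob_batches argument expects one [batch, seq_len] tensor per micro-batch, aligned with the labels. For callers that still hold one flat array of generated-token logprobs per sequence, a rough sketch of the re-batching (the helper name and padding scheme are assumptions, not part of this commit) could be:

import numpy as np
import torch
from typing import List

def batch_inference_logprobs(
    per_seq_logprobs: List[np.ndarray],
    label_batches: List[torch.Tensor],
) -> List[torch.Tensor]:
    """Hypothetical helper: scatter each sequence's generated-token logprobs
    into [batch, seq_len] tensors so they line up with labels != -100."""
    batches = []
    seq_iter = iter(per_seq_logprobs)
    for labels in label_batches:
        out = torch.zeros(labels.shape, dtype=torch.float32)
        for row in range(labels.size(0)):
            positions = (labels[row] != -100).nonzero(as_tuple=True)[0]
            lp = torch.from_numpy(np.asarray(next(seq_iter), dtype=np.float32))
            n = min(lp.numel(), positions.numel())
            out[row, positions[:n]] = lp[:n]
        batches.append(out)
    return batches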
@@ -441,19 +550,27 @@ def log_metrics(
if "gpu_memory_gb" in metrics:
timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB"
# Show interpretable loss (advantage-weighted logprobs) if available
interp_loss = metrics.get("interpretable_loss")
if interp_loss is not None:
print(f" AdvWeightedLogP: {interp_loss:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
else:
loss_str = (
f"{metrics['loss']:.6f}"
if abs(metrics["loss"]) < 0.01
else f"{metrics['loss']:.4f}"
)
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
# Primary metrics line: Loss and grad norm
loss_str = (
f"{metrics['loss']:.6f}"
if abs(metrics["loss"]) < 0.01
else f"{metrics['loss']:.4f}"
)
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
# Show GRPO-specific metrics if available
# GRPO metrics line: KL, ratio, clipping
kl_penalty = metrics.get("kl_penalty", 0)
mean_ratio = metrics.get("mean_ratio", 1.0)
mean_kl = metrics.get("mean_kl", 0)
clipped_frac = metrics.get("clipped_fraction", 0)
if kl_penalty > 0 or mean_kl > 0:
print(
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
f"clipped={clipped_frac*100:.1f}%"
)
# Advantage distribution
if "pos_count" in metrics or "neg_count" in metrics:
pos_count = metrics.get("pos_count", 0)
neg_count = metrics.get("neg_count", 0)
@@ -463,24 +580,6 @@ def log_metrics(
f" Advantages: +{int(pos_count)} / -{int(neg_count)}, "
f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}"
)
# Show logprob alignment stats (important for shared_vllm validation!)
if "logprob_alignment" in metrics:
alignment = metrics["logprob_alignment"]
if "logprobs/diff" in alignment:
diff = alignment["logprobs/diff"]
inf_mean = alignment.get("logprobs/inference_mean", 0)
train_mean = alignment.get("logprobs/training_mean", 0)
# NOTE: This comparison has a fundamental timing issue!
# - inference_logprobs: from vLLM at generation time (possibly stale)
# - training_logprobs: from trainer's current forward pass
# After training starts, weights change, making comparison invalid.
#
# NOTE: This diff is just for monitoring, not validation!
# The real-time test at startup is the reliable alignment check.
# This diff will naturally drift as training progresses (expected).
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
if use_wandb:
log_dict = {
@@ -488,6 +587,11 @@ def log_metrics(
"train/grad_norm": metrics["grad_norm"],
"train/pos_logp": metrics.get("pos_logp", 0),
"train/neg_logp": metrics.get("neg_logp", 0),
# GRPO-specific metrics
"grpo/kl_penalty": kl_penalty,
"grpo/mean_ratio": mean_ratio,
"grpo/mean_kl": mean_kl,
"grpo/clipped_fraction": clipped_frac,
}
# Add timing metrics if present
for key in ["step_time", "sync_time", "data_fetch_time",