mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
manual testing
This commit is contained in:
parent
da046d3d3b
commit
c1bb4f33f0
5 changed files with 329 additions and 766 deletions
|
|
@ -153,13 +153,22 @@ def compute_grpo_loss(
|
|||
temperatures: torch.Tensor,
|
||||
gradient_accumulation_steps: int,
|
||||
inference_logprobs: Optional[torch.Tensor] = None,
|
||||
kl_coef: float = 0.1,
|
||||
clip_eps: float = 0.2,
|
||||
use_reference_logprobs: bool = True,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
|
||||
|
||||
The GRPO loss encourages the model to:
|
||||
This implements proper GRPO/PPO with:
|
||||
- Importance sampling ratio: π(a|s) / π_old(a|s)
|
||||
- PPO-style clipping to prevent large updates
|
||||
- KL penalty to prevent reward hacking/policy collapse
|
||||
|
||||
The loss encourages the model to:
|
||||
- Increase probability for tokens with positive advantages
|
||||
- Decrease probability for tokens with negative advantages
|
||||
- Stay close to the reference policy (inference-time policy)
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for
|
||||
|
|
@ -168,7 +177,10 @@ def compute_grpo_loss(
|
|||
advantages: Advantage values [batch, 1]
|
||||
temperatures: Temperature values [batch, 1, 1]
|
||||
gradient_accumulation_steps: Number of accumulation steps (for scaling)
|
||||
inference_logprobs: Optional logprobs from inference for alignment check
|
||||
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
|
||||
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
|
||||
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
|
||||
use_reference_logprobs: If True, use inference_logprobs as reference policy
|
||||
|
||||
Returns:
|
||||
Tuple of (loss tensor, metrics dict)
|
||||
|
|
@ -177,14 +189,14 @@ def compute_grpo_loss(
|
|||
outputs = model(tokens)
|
||||
logits = outputs.logits
|
||||
|
||||
# Temperature scaling
|
||||
# Temperature scaling for training
|
||||
t = temperatures.to(logits.device, logits.dtype)
|
||||
t = torch.where(t <= 0, torch.ones_like(t), t)
|
||||
logits = logits / t
|
||||
scaled_logits = logits / t
|
||||
|
||||
# Log probabilities per token
|
||||
# Log probabilities per token (current policy π)
|
||||
logp_per_token = -F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)),
|
||||
scaled_logits.view(-1, scaled_logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
|
|
@ -192,39 +204,103 @@ def compute_grpo_loss(
|
|||
|
||||
# Masking based on labels != -100
|
||||
mask = (labels != -100).float()
|
||||
mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
|
||||
|
||||
# Compute metrics (no grad needed)
|
||||
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
|
||||
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
|
||||
|
||||
# === GRPO/PPO Loss Computation ===
|
||||
if use_reference_logprobs and inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
|
||||
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
# PPO-style clipping
|
||||
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
|
||||
|
||||
# Surrogate objectives
|
||||
surr1 = ratio * adv_expanded
|
||||
surr2 = clipped_ratio * adv_expanded
|
||||
|
||||
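# Select min(surr1, surr2) for non-negative advantages and max(surr1, surr2) for negative ones. NOTE(review): the standard PPO clipped objective takes min(surr1, surr2) for both signs; max for A < 0 is the optimistic, not the pessimistic, bound - confirm this is intended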
|
||||
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
|
||||
# -max(ratio * A, clipped_ratio * A) when A < 0
|
||||
policy_loss_per_token = -torch.where(
|
||||
adv_expanded >= 0,
|
||||
torch.min(surr1, surr2),
|
||||
torch.max(surr1, surr2),
|
||||
)
|
||||
|
||||
# Average over tokens, then over batch
|
||||
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
|
||||
# KL penalty: encourage staying close to reference policy
|
||||
# KL(π || π_ref) = E_π[log(π/π_ref)] = E_π[log_ratio], where π_ref is the reference (inference-time) policy
|
||||
# A common per-token estimator is KL ≈ (ratio - 1) - log(ratio); it is not the form used below
|
||||
# Simpler alternative used here: penalize the squared log-ratio, which is symmetric in π and π_ref
|
||||
if kl_coef > 0:
|
||||
# Approximate KL with (log_ratio)^2 (second-order Taylor term; the usual factor of 1/2 is absorbed into kl_coef)
|
||||
# (Using log_ratio directly as the penalty is an alternative; it is not what is computed below)
|
||||
kl_per_token = log_ratio.pow(2) # Squared for symmetric penalty
|
||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
|
||||
else:
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
|
||||
# Compute metrics for logging
|
||||
with torch.no_grad():
|
||||
# Fraction of tokens where ratio was clipped
|
||||
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
|
||||
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
||||
|
||||
# Mean ratio and KL for monitoring
|
||||
mean_ratio = (ratio * mask).sum() / mask.sum()
|
||||
mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
|
||||
|
||||
# For backward compatibility: collect training logprobs
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
else:
|
||||
# Fallback: REINFORCE-style (no reference policy)
|
||||
# This is what the original code did - NOT recommended!
|
||||
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
|
||||
|
||||
# Simple policy gradient: -log(π) * A
|
||||
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
|
||||
with torch.no_grad():
|
||||
clipped_fraction = torch.tensor(0.0)
|
||||
mean_ratio = torch.tensor(1.0)
|
||||
mean_kl = torch.tensor(0.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
|
||||
# === Compute Additional Metrics ===
|
||||
with torch.no_grad():
|
||||
pos = (advantages > 0).float()
|
||||
neg = (advantages <= 0).float()
|
||||
mask_float = mask.to(logp_per_token.dtype)
|
||||
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
|
||||
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
|
||||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
|
||||
# For alignment check: compute logprobs WITHOUT temperature scaling
|
||||
# This allows fair comparison with inference logprobs (which are at temp=1.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
|
||||
# Collect raw training logprobs for masked positions (generated tokens only)
|
||||
# Keep as PyTorch tensor (supports bfloat16 natively)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
|
||||
# GRPO loss: weighted log probabilities by advantages
|
||||
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
||||
grpo_loss = (
|
||||
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
|
||||
* advantages.to(logp_per_token.device)
|
||||
).mean() / gradient_accumulation_steps
|
||||
|
||||
# Compute a more interpretable loss metric (advantage-weighted logprobs)
|
||||
with torch.no_grad():
|
||||
# Interpretable metric: advantage-weighted average logprob
|
||||
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
||||
|
||||
metrics = {
|
||||
|
|
@ -233,11 +309,16 @@ def compute_grpo_loss(
|
|||
"avg_logp": avg_logp,
|
||||
"pos_count": pos.sum().item(),
|
||||
"neg_count": neg.sum().item(),
|
||||
"training_logprobs": training_logprobs_flat, # For alignment check
|
||||
"interpretable_loss": interpretable_loss, # More meaningful metric
|
||||
"training_logprobs": training_logprobs_flat,
|
||||
"interpretable_loss": interpretable_loss,
|
||||
# GRPO-specific metrics
|
||||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
||||
}
|
||||
|
||||
return grpo_loss, metrics
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
def compute_logprob_alignment(
|
||||
|
|
@ -309,17 +390,16 @@ def run_training_step(
|
|||
advantage_batches: List[torch.Tensor],
|
||||
temperature_batches: List[torch.Tensor],
|
||||
config: TrainingConfig,
|
||||
inference_logprobs: Optional[List[np.ndarray]] = None,
|
||||
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Run a single training step with gradient accumulation.
|
||||
|
||||
Performs:
|
||||
1. Forward pass through all micro-batches
|
||||
1. Forward pass through all micro-batches with proper GRPO loss
|
||||
2. Backward pass with gradient accumulation
|
||||
3. Gradient clipping
|
||||
4. Optimizer step
|
||||
5. (Optional) Logprob alignment check
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
|
|
@ -328,8 +408,8 @@ def run_training_step(
|
|||
label_batches: List of label tensors
|
||||
advantage_batches: List of advantage tensors
|
||||
temperature_batches: List of temperature tensors
|
||||
config: Training configuration
|
||||
inference_logprobs: Optional logprobs from inference for alignment check
|
||||
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
|
||||
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
|
||||
|
||||
Returns:
|
||||
Dict of training metrics for this step
|
||||
|
|
@ -341,16 +421,32 @@ def run_training_step(
|
|||
total_neg_logp = 0.0
|
||||
total_pos = 0.0
|
||||
total_neg = 0.0
|
||||
total_kl_penalty = 0.0
|
||||
total_mean_ratio = 0.0
|
||||
total_mean_kl = 0.0
|
||||
total_clipped_fraction = 0.0
|
||||
grad_norm = 0.0
|
||||
all_training_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
||||
clip_eps = getattr(config, 'clip_eps', 0.2)
|
||||
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
|
||||
|
||||
# Accumulate gradients over micro-batches
|
||||
for tokens, labels, advantages, temperatures in zip(
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
|
||||
token_batches, label_batches, advantage_batches, temperature_batches
|
||||
):
|
||||
)):
|
||||
tokens = tokens.to(config.device)
|
||||
labels = labels.to(config.device)
|
||||
advantages = advantages.to(config.device)
|
||||
|
||||
# Get corresponding inference logprobs batch if available
|
||||
inf_logprobs = None
|
||||
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
|
||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
model,
|
||||
|
|
@ -359,6 +455,10 @@ def run_training_step(
|
|||
advantages,
|
||||
temperatures,
|
||||
config.gradient_accumulation_steps,
|
||||
inference_logprobs=inf_logprobs,
|
||||
kl_coef=kl_coef,
|
||||
clip_eps=clip_eps,
|
||||
use_reference_logprobs=use_reference_logprobs,
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
|
@ -368,7 +468,13 @@ def run_training_step(
|
|||
total_pos += metrics["pos_count"]
|
||||
total_neg += metrics["neg_count"]
|
||||
|
||||
# Collect training logprobs for alignment check
|
||||
# Accumulate GRPO-specific metrics
|
||||
total_kl_penalty += metrics.get("kl_penalty", 0.0)
|
||||
total_mean_ratio += metrics.get("mean_ratio", 1.0)
|
||||
total_mean_kl += metrics.get("mean_kl", 0.0)
|
||||
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
|
||||
|
||||
# Collect training logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics:
|
||||
all_training_logprobs.append(metrics["training_logprobs"])
|
||||
|
||||
|
|
@ -380,8 +486,7 @@ def run_training_step(
|
|||
# Help prevent memory fragmentation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Normalize metrics by count
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
# Normalize metrics by batch count
|
||||
if total_pos > 0:
|
||||
total_pos_logp /= num_batches
|
||||
if total_neg > 0:
|
||||
|
|
@ -394,18 +499,22 @@ def run_training_step(
|
|||
"neg_logp": total_neg_logp,
|
||||
"pos_count": total_pos,
|
||||
"neg_count": total_neg,
|
||||
# GRPO-specific metrics (averaged over batches)
|
||||
"kl_penalty": total_kl_penalty / num_batches,
|
||||
"mean_ratio": total_mean_ratio / num_batches,
|
||||
"mean_kl": total_mean_kl / num_batches,
|
||||
"clipped_fraction": total_clipped_fraction / num_batches,
|
||||
}
|
||||
|
||||
# Compute logprob alignment stats
|
||||
# NOTE: This comparison is approximate - inference and training logprobs
|
||||
# come from different batching, so token-by-token alignment isn't possible.
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
if inference_logprobs is not None and all_training_logprobs:
|
||||
alignment_stats = compute_logprob_alignment(
|
||||
inference_logprobs, all_training_logprobs, debug=False
|
||||
)
|
||||
_logprob_alignment_stats.update(alignment_stats)
|
||||
result["logprob_alignment"] = alignment_stats
|
||||
# Compute logprob alignment stats for monitoring
|
||||
# NOTE: Now that we use proper GRPO, this is less critical
|
||||
# but still useful for debugging weight sharing issues
|
||||
if all_training_logprobs:
|
||||
# Store training logprob stats
|
||||
train_flat = torch.cat(all_training_logprobs)
|
||||
if train_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -441,19 +550,27 @@ def log_metrics(
|
|||
if "gpu_memory_gb" in metrics:
|
||||
timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB"
|
||||
|
||||
# Show interpretable loss (advantage-weighted logprobs) if available
|
||||
interp_loss = metrics.get("interpretable_loss")
|
||||
if interp_loss is not None:
|
||||
print(f" AdvWeightedLogP: {interp_loss:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
else:
|
||||
loss_str = (
|
||||
f"{metrics['loss']:.6f}"
|
||||
if abs(metrics["loss"]) < 0.01
|
||||
else f"{metrics['loss']:.4f}"
|
||||
)
|
||||
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
# Primary metrics line: Loss and grad norm
|
||||
loss_str = (
|
||||
f"{metrics['loss']:.6f}"
|
||||
if abs(metrics["loss"]) < 0.01
|
||||
else f"{metrics['loss']:.4f}"
|
||||
)
|
||||
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
|
||||
# Show GRPO-specific metrics if available
|
||||
# GRPO metrics line: KL, ratio, clipping
|
||||
kl_penalty = metrics.get("kl_penalty", 0)
|
||||
mean_ratio = metrics.get("mean_ratio", 1.0)
|
||||
mean_kl = metrics.get("mean_kl", 0)
|
||||
clipped_frac = metrics.get("clipped_fraction", 0)
|
||||
|
||||
if kl_penalty > 0 or mean_kl > 0:
|
||||
print(
|
||||
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
|
||||
f"clipped={clipped_frac*100:.1f}%"
|
||||
)
|
||||
|
||||
# Advantage distribution
|
||||
if "pos_count" in metrics or "neg_count" in metrics:
|
||||
pos_count = metrics.get("pos_count", 0)
|
||||
neg_count = metrics.get("neg_count", 0)
|
||||
|
|
@ -463,24 +580,6 @@ def log_metrics(
|
|||
f" Advantages: +{int(pos_count)} / -{int(neg_count)}, "
|
||||
f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}"
|
||||
)
|
||||
|
||||
# Show logprob alignment stats (important for shared_vllm validation!)
|
||||
if "logprob_alignment" in metrics:
|
||||
alignment = metrics["logprob_alignment"]
|
||||
if "logprobs/diff" in alignment:
|
||||
diff = alignment["logprobs/diff"]
|
||||
inf_mean = alignment.get("logprobs/inference_mean", 0)
|
||||
train_mean = alignment.get("logprobs/training_mean", 0)
|
||||
|
||||
# NOTE: This comparison has a fundamental timing issue!
|
||||
# - inference_logprobs: from vLLM at generation time (possibly stale)
|
||||
# - training_logprobs: from trainer's current forward pass
|
||||
# After training starts, weights change, making comparison invalid.
|
||||
#
|
||||
# NOTE: This diff is just for monitoring, not validation!
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
# This diff will naturally drift as training progresses (expected).
|
||||
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
|
||||
|
||||
if use_wandb:
|
||||
log_dict = {
|
||||
|
|
@ -488,6 +587,11 @@ def log_metrics(
|
|||
"train/grad_norm": metrics["grad_norm"],
|
||||
"train/pos_logp": metrics.get("pos_logp", 0),
|
||||
"train/neg_logp": metrics.get("neg_logp", 0),
|
||||
# GRPO-specific metrics
|
||||
"grpo/kl_penalty": kl_penalty,
|
||||
"grpo/mean_ratio": mean_ratio,
|
||||
"grpo/mean_kl": mean_kl,
|
||||
"grpo/clipped_fraction": clipped_frac,
|
||||
}
|
||||
# Add timing metrics if present
|
||||
for key in ["step_time", "sync_time", "data_fetch_time",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue