mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
cleanup
This commit is contained in:
parent
04f2850980
commit
0b61dd047a
2 changed files with 453 additions and 540 deletions
|
|
@ -12,7 +12,6 @@ import string
|
|||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb
|
||||
|
|
@ -131,7 +130,7 @@ def compute_grpo_loss(
|
|||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
|
||||
# DEBUG: Check if inference logprobs look valid
|
||||
|
||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||
with torch.no_grad():
|
||||
# Only look at generated positions (where mask == 1)
|
||||
|
|
@ -146,7 +145,7 @@ def compute_grpo_loss(
|
|||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||
print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
|
||||
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
|
||||
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
|
|
@ -159,7 +158,7 @@ def compute_grpo_loss(
|
|||
|
||||
# Pessimistic bound: min for positive advantages, max for negative
|
||||
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
|
||||
# -max(ratio * A, clipped_ratio * A) when A < 0
|
||||
# -max(ratio * A, clipped_ratio * A) when A < 0
|
||||
policy_loss_per_token = -torch.where(
|
||||
adv_expanded >= 0,
|
||||
torch.min(surr1, surr2),
|
||||
|
|
@ -171,14 +170,6 @@ def compute_grpo_loss(
|
|||
|
||||
# KL penalty: encourage staying close to reference policy
|
||||
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
|
||||
# D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1
|
||||
#
|
||||
# In terms of log probabilities:
|
||||
# log_ratio = log π_θ - log π_ref (what we computed above)
|
||||
# ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ
|
||||
# kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1
|
||||
# = exp(-log_ratio) + log_ratio - 1
|
||||
#
|
||||
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
|
||||
if kl_coef > 0:
|
||||
# Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
|
||||
|
|
@ -211,26 +202,18 @@ def compute_grpo_loss(
|
|||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
else:
|
||||
# Fallback: REINFORCE-style (no reference policy)
|
||||
# This is what the original code did - NOT recommended!
|
||||
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
|
||||
|
||||
# Simple policy gradient: -log(π) * A
|
||||
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
|
||||
with torch.no_grad():
|
||||
clipped_fraction = torch.tensor(0.0)
|
||||
mean_ratio = torch.tensor(1.0)
|
||||
mean_kl = torch.tensor(0.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
# Fail loudly
|
||||
raise ValueError(
|
||||
"GRPO requires inference_logprobs for importance sampling!\n"
|
||||
"\n"
|
||||
"This error means the environment isn't providing logprobs. To fix:\n"
|
||||
" 1. Use --openai.server_type vllm (not 'openai')\n"
|
||||
" 2. Ensure vLLM is returning logprobs in /generate response\n"
|
||||
" 3. Check that gsm8k_server is configured correctly\n"
|
||||
"\n"
|
||||
"Without inference logprobs, training will cause reward hacking.\n"
|
||||
"If you REALLY want vanilla REINFORCE (not recommended), set use_reference_logprobs=False"
|
||||
)
|
||||
|
||||
# === Compute Additional Metrics ===
|
||||
with torch.no_grad():
|
||||
|
|
@ -262,67 +245,6 @@ def compute_grpo_loss(
|
|||
return total_loss, metrics
|
||||
|
||||
|
||||
def compute_logprob_alignment(
|
||||
inference_logprobs: List[np.ndarray],
|
||||
training_logprobs: List[torch.Tensor],
|
||||
debug: bool = False,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute alignment stats between inference and training logprobs.
|
||||
|
||||
At initialization (step 0), these should match closely if the model
|
||||
weights are correctly shared between training and inference.
|
||||
|
||||
Args:
|
||||
inference_logprobs: Logprobs from vLLM inference (numpy arrays)
|
||||
training_logprobs: Logprobs computed during training forward pass (PyTorch tensors, bfloat16 supported)
|
||||
debug: If True, print detailed debugging info
|
||||
|
||||
Returns:
|
||||
Dict of alignment statistics
|
||||
"""
|
||||
if not inference_logprobs or not training_logprobs:
|
||||
return {}
|
||||
|
||||
# Process inference logprobs (numpy)
|
||||
inf_flat = np.concatenate(inference_logprobs)
|
||||
# Filter out placeholder values (1.0 or 0.0 used for prompt tokens)
|
||||
inf_mask = (inf_flat != 1.0) & (inf_flat != 0.0)
|
||||
inf_filtered = inf_flat[inf_mask]
|
||||
|
||||
# Process training logprobs (PyTorch - supports bfloat16 natively)
|
||||
train_flat = torch.cat(training_logprobs)
|
||||
|
||||
if debug:
|
||||
print(f" [DEBUG] Inference: {len(inf_flat)} total, {len(inf_filtered)} after filter")
|
||||
print(f" [DEBUG] Training: {train_flat.numel()} logprobs")
|
||||
if len(inf_filtered) > 0:
|
||||
print(f" [DEBUG] Inf sample (first 5): {inf_filtered[:5]}")
|
||||
if train_flat.numel() > 0:
|
||||
print(f" [DEBUG] Train sample (first 5): {train_flat[:5].tolist()}")
|
||||
|
||||
# Compute stats using PyTorch for training (keeps bfloat16 precision)
|
||||
stats = {}
|
||||
|
||||
if len(inf_filtered) > 0:
|
||||
stats["logprobs/inference_mean"] = float(np.mean(inf_filtered))
|
||||
stats["logprobs/inference_std"] = float(np.std(inf_filtered))
|
||||
|
||||
if train_flat.numel() > 0:
|
||||
# PyTorch operations - fully support bfloat16
|
||||
stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
# Compute diff (for tracking, not validation)
|
||||
# NOTE: Per-token comparison is NOT reliable here because inference and training
|
||||
# logprobs come from different batch orderings and can't be aligned token-by-token.
|
||||
# The real-time test at startup is the proper alignment validation.
|
||||
if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats:
|
||||
stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def run_training_step(
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue