remove training code

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:52:52 -04:00
parent 862cd3667d
commit 148a4fd5eb
6 changed files with 38 additions and 329 deletions

View file

@ -70,11 +70,6 @@ def compute_grpo_loss(
gradient_accumulation_steps: int,
inference_logprobs: Optional[torch.Tensor] = None,
clip_eps: float = 0.2,
distill_token_ids: Optional[torch.Tensor] = None,
distill_logprobs: Optional[torch.Tensor] = None,
distill_enabled: bool = False,
distill_coef: float = 0.0,
distill_temperature: float = 1.0,
) -> Tuple[torch.Tensor, dict]:
"""
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
@ -130,9 +125,6 @@ def compute_grpo_loss(
logprob_diff_abs_mean = 0.0
logprob_diff_max = 0.0
distill_loss_value = torch.tensor(0.0, device=logp_per_token.device)
distill_token_count = 0.0
# === GRPO/PPO Loss Computation ===
if inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
@ -195,23 +187,7 @@ def compute_grpo_loss(
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
if (
distill_enabled
and distill_coef > 0
and distill_token_ids is not None
and distill_logprobs is not None
):
distill_loss_value, distill_token_count = compute_distillation_loss(
logits=scaled_logits,
labels=labels,
distill_token_ids=distill_token_ids.to(logits.device),
distill_logprobs=distill_logprobs.to(logits.device, logits.dtype),
temperature=max(1e-6, float(distill_temperature)),
)
total_loss = (policy_loss + distill_coef * distill_loss_value) / (
gradient_accumulation_steps
)
total_loss = policy_loss / gradient_accumulation_steps
# Compute metrics for logging
with torch.no_grad():
@ -277,77 +253,11 @@ def compute_grpo_loss(
"logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean,
"logprob_diff_max": logprob_diff_max,
"distill_loss": (
distill_loss_value.item()
if torch.is_tensor(distill_loss_value)
else float(distill_loss_value)
),
"distill_token_count": distill_token_count,
}
return total_loss, metrics
def compute_distillation_loss(
logits: torch.Tensor,
labels: torch.Tensor,
distill_token_ids: torch.Tensor,
distill_logprobs: torch.Tensor,
temperature: float = 1.0,
) -> Tuple[torch.Tensor, float]:
"""
Compute token-level distillation loss from teacher top-k prompt logprobs.
Args:
logits: Student logits [batch, seq_len, vocab]
labels: Labels [batch, seq_len], -100 for masked positions
distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded
distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded
temperature: Distillation temperature
Returns:
Tuple of (distillation loss scalar, valid token count)
"""
if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
if (
distill_token_ids.shape[:2] != labels.shape
or distill_logprobs.shape != distill_token_ids.shape
):
return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
temp = max(1e-6, float(temperature))
valid_ids = distill_token_ids >= 0
label_mask = labels != -100
valid_pos = label_mask & valid_ids.any(dim=-1)
if not valid_pos.any():
return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
gather_ids = distill_token_ids.clamp_min(0).long()
# Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor
# (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs.
# Instead: gather raw logits at top-k positions, then subtract logsumexp.
# Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
scaled_logits = logits / temp
log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1]
student_logp_topk = (
torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
)
masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1)
per_token_loss = per_token_loss * valid_pos.to(per_token_loss.dtype)
token_count = valid_pos.sum().item()
loss = per_token_loss.sum() / valid_pos.sum().clamp_min(1).to(per_token_loss.dtype)
return loss, float(token_count)
def run_training_step(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
@ -358,8 +268,6 @@ def run_training_step(
config: TrainingConfig,
step_idx: int,
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
distill_token_id_batches: Optional[List[torch.Tensor]] = None,
distill_logprob_batches: Optional[List[torch.Tensor]] = None,
) -> dict:
"""
Run a single training step with gradient accumulation.
@ -394,8 +302,6 @@ def run_training_step(
total_logprob_diff_mean = 0.0
total_logprob_diff_abs_mean = 0.0
total_logprob_diff_max = 0.0
total_distill_loss = 0.0
total_distill_tokens = 0.0
grad_norm = 0.0
all_training_logprobs: List[torch.Tensor] = []
all_inference_logprobs: List[torch.Tensor] = []
@ -419,13 +325,6 @@ def run_training_step(
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
zip(token_batches, label_batches, advantage_batches, temperature_batches)
):
print(
f" [Step] micro-batch {batch_idx+1}/{num_batches} "
f"tokens={tokens.shape} "
f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB "
f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB",
flush=True,
)
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
@ -436,18 +335,7 @@ def run_training_step(
inference_logprob_batches
):
inf_logprobs = inference_logprob_batches[batch_idx]
distill_ids = None
if distill_token_id_batches is not None and batch_idx < len(
distill_token_id_batches
):
distill_ids = distill_token_id_batches[batch_idx]
distill_lps = None
if distill_logprob_batches is not None and batch_idx < len(
distill_logprob_batches
):
distill_lps = distill_logprob_batches[batch_idx]
print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True)
loss, metrics = compute_grpo_loss(
model,
tokens,
@ -457,20 +345,9 @@ def run_training_step(
config.gradient_accumulation_steps,
inference_logprobs=inf_logprobs,
clip_eps=clip_eps,
distill_token_ids=distill_ids,
distill_logprobs=distill_lps,
distill_enabled=bool(getattr(config, "distill_enabled", False)),
distill_coef=float(getattr(config, "distill_coef", 0.0)),
distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
)
print(
f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} "
f"backward...",
flush=True,
)
loss.backward()
print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
total_loss += loss.item()
total_pos_logp += metrics["pos_logp"]
total_neg_logp += metrics["neg_logp"]
@ -487,8 +364,6 @@ def run_training_step(
total_logprob_diff_max = max(
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
)
total_distill_loss += metrics.get("distill_loss", 0.0)
total_distill_tokens += metrics.get("distill_token_count", 0.0)
# Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
@ -524,8 +399,6 @@ def run_training_step(
# GRPO-specific metrics (averaged over batches)
"mean_ratio": total_mean_ratio / num_batches,
"clipped_fraction": total_clipped_fraction / num_batches,
"distill_loss": total_distill_loss / num_batches,
"distill_token_count": total_distill_tokens,
}
# Compute logprob alignment stats for monitoring
@ -599,12 +472,6 @@ def log_metrics(
clipped_frac = metrics.get("clipped_fraction", 0)
print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%")
if metrics.get("distill_token_count", 0) > 0:
print(
" Distill: "
f"loss={metrics.get('distill_loss', 0.0):.4f}, "
f"tokens={int(metrics.get('distill_token_count', 0))}"
)
# Advantage distribution
if "pos_count" in metrics or "neg_count" in metrics:
@ -627,8 +494,6 @@ def log_metrics(
# GRPO-specific metrics
"grpo/mean_ratio": mean_ratio,
"grpo/clipped_fraction": clipped_frac,
"distill/loss": metrics.get("distill_loss", 0.0),
"distill/token_count": metrics.get("distill_token_count", 0.0),
}
# Add timing metrics if present
for key in [