training kernel

commit 62ef2fcc2e
parent a54dfe7a13
Author: Jai Suphavadeeprasit
Date: 2026-03-12 12:20:54 -04:00


@@ -315,7 +315,6 @@ def compute_distillation_loss(
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
     temp = max(1e-6, float(temperature))
-    student_log_probs = F.log_softmax(logits / temp, dim=-1)
     valid_ids = distill_token_ids >= 0
     label_mask = labels != -100
@@ -324,7 +323,14 @@ def compute_distillation_loss(
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
     gather_ids = distill_token_ids.clamp_min(0).long()
-    student_logp_topk = torch.gather(student_log_probs, dim=-1, index=gather_ids)
+    # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor
+    # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs.
+    # Instead: gather raw logits at top-k positions, then subtract logsumexp.
+    # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
+    scaled_logits = logits / temp
+    log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
+    student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
     masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
     teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
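
Note: the replacement relies on the identity log_softmax(x)[i] = x[i] - logsumexp(x), so gathering the scaled logits first and normalizing afterwards gives the same top-k log-probs while only allocating a [b, s, 1] normalizer and a [b, s, k] output instead of a second full-vocab tensor. A minimal sketch of the equivalence, with illustrative shapes rather than the real run config:

import torch
import torch.nn.functional as F

b, s, vocab, k = 2, 64, 4096, 8  # illustrative; the real shapes are ~[2, 20480, 151936]
logits = torch.randn(b, s, vocab)
gather_ids = torch.randint(0, vocab, (b, s, k))
temp = 0.7

# Old path: materialize the full [b, s, vocab] log_softmax output, then gather top-k.
old = torch.gather(F.log_softmax(logits / temp, dim=-1), dim=-1, index=gather_ids)

# New path: gather first, subtract the log-normalizer afterwards.
scaled = logits / temp
new = torch.gather(scaled, dim=-1, index=gather_ids) - torch.logsumexp(scaled, dim=-1, keepdim=True)

assert torch.allclose(old, new, atol=1e-5)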
@@ -408,6 +414,13 @@ def run_training_step(
     for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
         zip(token_batches, label_batches, advantage_batches, temperature_batches)
     ):
+        print(
+            f" [Step] micro-batch {batch_idx+1}/{num_batches} "
+            f"tokens={tokens.shape} "
+            f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB "
+            f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB",
+            flush=True,
+        )
         tokens = tokens.to(config.device)
         labels = labels.to(config.device)
         advantages = advantages.to(config.device)
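
Note: the two counters printed per micro-batch measure different things: torch.cuda.memory_allocated() is the bytes currently occupied by live tensors, while torch.cuda.memory_reserved() is the total the caching allocator has claimed from the device, so a growing gap between them usually points at fragmentation rather than real tensor growth. A tiny hypothetical helper (name and format are illustrative, not part of this commit) packaging the same calls:

import torch

def cuda_mem_summary(prefix: str = "") -> str:
    # Live-tensor bytes vs. bytes reserved by the caching allocator, in GB.
    alloc_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    return f"{prefix}gpu_mem={alloc_gb:.1f}GB gpu_reserved={reserved_gb:.1f}GB"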
@@ -429,6 +442,7 @@ def run_training_step(
         ):
             distill_lps = distill_logprob_batches[batch_idx]
+        print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True)
         loss, metrics = compute_grpo_loss(
             model,
             tokens,
@@ -445,7 +459,13 @@ def run_training_step(
             distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
         )
+        print(
+            f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} "
+            f"backward...",
+            flush=True,
+        )
         loss.backward()
+        print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
         total_loss += loss.item()
         total_pos_logp += metrics["pos_logp"]
         total_neg_logp += metrics["neg_logp"]
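
Note: loss.item() copies the scalar to the host and therefore blocks until the forward pass's CUDA kernels have finished, so the "backward..." line only prints once the forward has genuinely completed; that sync point is what makes these logs useful for pinpointing a hang, at the cost of one extra stall per micro-batch. A small illustration of the blocking behaviour (timings are illustrative; it only runs when CUDA is available):

import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    t0 = time.perf_counter()
    y = x @ x                # kernel launch is asynchronous; this returns almost immediately
    launch_ms = (time.perf_counter() - t0) * 1e3
    t0 = time.perf_counter()
    _ = y.sum().item()       # .item() waits for the matmul to finish before copying to host
    item_ms = (time.perf_counter() - t0) * 1e3
    print(f"launch: {launch_ms:.2f} ms, item(): {item_ms:.2f} ms")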