Mirror of https://github.com/NousResearch/atropos.git (synced 2026-05-01 17:45:16 +00:00)

training kernel

Commit: 62ef2fcc2e
Parent: a54dfe7a13
1 changed file with 22 additions and 2 deletions
@@ -315,7 +315,6 @@ def compute_distillation_loss(
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
 
     temp = max(1e-6, float(temperature))
-    student_log_probs = F.log_softmax(logits / temp, dim=-1)
 
     valid_ids = distill_token_ids >= 0
     label_mask = labels != -100
@@ -324,7 +323,14 @@ def compute_distillation_loss(
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
 
     gather_ids = distill_token_ids.clamp_min(0).long()
-    student_logp_topk = torch.gather(student_log_probs, dim=-1, index=gather_ids)
+
+    # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor
+    # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs.
+    # Instead: gather raw logits at top-k positions, then subtract logsumexp.
+    # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
+    scaled_logits = logits / temp
+    log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
+    student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
 
     masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
     teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
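The rewrite in this hunk relies on the identity log_softmax(z/T)[i] = z[i]/T - logsumexp(z/T): gathering the k teacher token positions first and normalizing afterwards keeps every newly allocated tensor at [batch, seq_len, k] instead of [batch, seq_len, vocab]. A minimal standalone sketch of that equivalence, with made-up toy shapes (nothing here is taken from the repo beyond the pattern shown in the hunk):

import torch
import torch.nn.functional as F

# Hypothetical toy shapes; the commit's real tensors are e.g. [2, 20480, 151936].
batch, seq_len, vocab, k = 2, 8, 512, 16
logits = torch.randn(batch, seq_len, vocab)
gather_ids = torch.randint(0, vocab, (batch, seq_len, k))
temp = 0.7

# Memory-heavy baseline: materializes a full-vocab log_softmax output.
full = torch.gather(F.log_softmax(logits / temp, dim=-1), dim=-1, index=gather_ids)

# Memory-light form from the diff: gather scaled logits, then subtract logsumexp.
scaled = logits / temp
log_z = torch.logsumexp(scaled, dim=-1, keepdim=True)           # [batch, seq_len, 1]
light = torch.gather(scaled, dim=-1, index=gather_ids) - log_z  # [batch, seq_len, k]

assert torch.allclose(full, light, atol=1e-5)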
@@ -408,6 +414,13 @@ def run_training_step(
     for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
         zip(token_batches, label_batches, advantage_batches, temperature_batches)
     ):
+        print(
+            f" [Step] micro-batch {batch_idx+1}/{num_batches} "
+            f"tokens={tokens.shape} "
+            f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB "
+            f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB",
+            flush=True,
+        )
         tokens = tokens.to(config.device)
         labels = labels.to(config.device)
         advantages = advantages.to(config.device)
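A note on the two numbers logged above: torch.cuda.memory_allocated() reports memory currently occupied by live tensors, while torch.cuda.memory_reserved() reports the larger pool held by PyTorch's caching allocator, so a wide gap between them points at allocator fragmentation rather than live data. A small helper in the same spirit as the diff's log line (the function name is ours, not the repo's):

import torch

def gpu_mem_summary() -> str:
    # memory_allocated: bytes held by live tensors.
    # memory_reserved: bytes held by the caching allocator (always >= allocated).
    alloc_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    return f"gpu_mem={alloc_gb:.1f}GB gpu_reserved={reserved_gb:.1f}GB"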
@@ -429,6 +442,7 @@ def run_training_step(
     ):
         distill_lps = distill_logprob_batches[batch_idx]
 
+        print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True)
         loss, metrics = compute_grpo_loss(
             model,
             tokens,
@@ -445,7 +459,13 @@ def run_training_step(
             distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
         )
 
+        print(
+            f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} "
+            f"backward...",
+            flush=True,
+        )
         loss.backward()
+        print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
         total_loss += loss.item()
         total_pos_logp += metrics["pos_logp"]
         total_neg_logp += metrics["neg_logp"]
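For context, the prints added in run_training_step instrument a standard gradient-accumulation loop: one forward and one backward per micro-batch, with gradients summed in place until the optimizer steps. A hedged sketch of that skeleton, using hypothetical model, optimizer, and loss_fn stand-ins (only the loop shape mirrors the diff; compute_grpo_loss and its arguments are the repo's, everything else here is ours):

import torch

def run_micro_batches(model, optimizer, loss_fn, token_batches, label_batches):
    # Hypothetical skeleton of the loop the diff instruments; not the repo's code.
    total_loss = 0.0
    num_batches = len(token_batches)
    optimizer.zero_grad()
    for batch_idx, (tokens, labels) in enumerate(zip(token_batches, label_batches)):
        print(f" [Step] micro-batch {batch_idx+1}/{num_batches} forward pass...", flush=True)
        loss = loss_fn(model, tokens, labels)  # stand-in for compute_grpo_loss(...)
        loss.backward()  # gradients accumulate across micro-batches
        print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
        total_loss += loss.item()
    optimizer.step()
    return total_loss / num_batches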