diff --git a/example_trainer/training.py b/example_trainer/training.py index c5b739e9..673ed795 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -315,7 +315,6 @@ def compute_distillation_loss( return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 temp = max(1e-6, float(temperature)) - student_log_probs = F.log_softmax(logits / temp, dim=-1) valid_ids = distill_token_ids >= 0 label_mask = labels != -100 @@ -324,7 +323,14 @@ def compute_distillation_loss( return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 gather_ids = distill_token_ids.clamp_min(0).long() - student_logp_topk = torch.gather(student_log_probs, dim=-1, index=gather_ids) + + # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor + # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs. + # Instead: gather raw logits at top-k positions, then subtract logsumexp. + # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab]. + scaled_logits = logits / temp + log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1] + student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) @@ -408,6 +414,13 @@ def run_training_step( for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( zip(token_batches, label_batches, advantage_batches, temperature_batches) ): + print( + f" [Step] micro-batch {batch_idx+1}/{num_batches} " + f"tokens={tokens.shape} " + f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB " + f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB", + flush=True, + ) tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) @@ -429,6 +442,7 @@ def run_training_step( ): distill_lps = distill_logprob_batches[batch_idx] + print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True) loss, metrics = compute_grpo_loss( model, tokens, @@ -445,7 +459,13 @@ def run_training_step( distill_temperature=float(getattr(config, "distill_temperature", 1.0)), ) + print( + f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} " + f"backward...", + flush=True, + ) loss.backward() + print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True) total_loss += loss.item() total_pos_logp += metrics["pos_logp"] total_neg_logp += metrics["neg_logp"]