diff --git a/example_trainer/training.py b/example_trainer/training.py index 6c41556e..a242417f 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -121,9 +121,18 @@ def compute_grpo_loss( pos_logp = (logp_per_token * pos).mean().item() neg_logp = (logp_per_token * neg).mean().item() - # Collect training logprobs for masked positions (generated tokens only) + # For alignment check: compute logprobs WITHOUT temperature scaling + # This allows fair comparison with inference logprobs (which are at temp=1.0) + raw_logp_per_token = -F.cross_entropy( + outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled + labels.view(-1), + reduction="none", + ignore_index=-100, + ).view(labels.shape) + + # Collect raw training logprobs for masked positions (generated tokens only) # Keep as PyTorch tensor (supports bfloat16 natively) - training_logprobs_flat = logp_per_token[mask.bool()].detach() + training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() # GRPO loss: weighted log probabilities by advantages grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())