readme updates

This commit is contained in:
Jai Suphavadeeprasit 2026-01-27 14:38:12 -05:00
parent e34ac31ed7
commit d23dfe75b4

View file

@@ -121,9 +121,18 @@ def compute_grpo_loss(
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
# Collect training logprobs for masked positions (generated tokens only)
# For alignment check: compute logprobs WITHOUT temperature scaling
# This allows fair comparison with inference logprobs (which are at temp=1.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
# Collect raw training logprobs for masked positions (generated tokens only)
# Keep as PyTorch tensor (supports bfloat16 natively)
training_logprobs_flat = logp_per_token[mask.bool()].detach()
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# GRPO loss: weighted log probabilities by advantages
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())