mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
readme updates
This commit is contained in:
parent
e34ac31ed7
commit
d23dfe75b4
1 changed files with 11 additions and 2 deletions
|
|
@ -121,9 +121,18 @@ def compute_grpo_loss(
|
|||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
|
||||
# Collect training logprobs for masked positions (generated tokens only)
|
||||
# For alignment check: compute logprobs WITHOUT temperature scaling
|
||||
# This allows fair comparison with inference logprobs (which are at temp=1.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
|
||||
# Collect raw training logprobs for masked positions (generated tokens only)
|
||||
# Keep as PyTorch tensor (supports bfloat16 natively)
|
||||
training_logprobs_flat = logp_per_token[mask.bool()].detach()
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
|
||||
# GRPO loss: weighted log probabilities by advantages
|
||||
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue