Update grpo.py

This commit is contained in:
Brawn 2025-08-14 20:20:41 +03:00 committed by GitHub
parent 6dccdcc67e
commit eb179e7fca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -371,8 +371,9 @@ def train(config: TrainingConfig):
with torch.no_grad():
pos = (advantages > 0).float()
neg = (advantages <= 0).float()
-mask_sum = mask.sum(-1).clamp(min=1e-8)
-avg_logp = (logp_per_token * mask).sum(-1)
+mask = mask.to(logp_per_token.dtype)
+mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
+avg_logp = (logp_per_token * mask).sum(dim=-1) / mask_sum
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
total_pos_logp += pos_logp