mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Update grpo.py
This commit is contained in:
parent
6dccdcc67e
commit
eb179e7fca
1 changed files with 3 additions and 2 deletions
|
|
@ -371,8 +371,9 @@ def train(config: TrainingConfig):
|
|||
with torch.no_grad():
|
||||
pos = (advantages > 0).float()
|
||||
neg = (advantages <= 0).float()
|
||||
mask_sum = mask.sum(-1).clamp(min=1e-8)
|
||||
avg_logp = (logp_per_token * mask).sum(-1)
|
||||
mask = mask.to(logp_per_token.dtype)
|
||||
mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
|
||||
avg_logp = (logp_per_token * mask).sum(dim=-1) / mask_sum
|
||||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
total_pos_logp += pos_logp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue