From eb179e7fca0648072053c235baa25e82781eabca Mon Sep 17 00:00:00 2001 From: Brawn Date: Thu, 14 Aug 2025 20:20:41 +0300 Subject: [PATCH] Update grpo.py --- example_trainer/grpo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 8f11c7fd..fb340a2f 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -371,8 +371,9 @@ def train(config: TrainingConfig): with torch.no_grad(): pos = (advantages > 0).float() neg = (advantages <= 0).float() - mask_sum = mask.sum(-1).clamp(min=1e-8) - avg_logp = (logp_per_token * mask).sum(-1) + mask = mask.to(logp_per_token.dtype) + mask_sum = mask.sum(dim=-1).clamp_min(1e-8) + avg_logp = (logp_per_token * mask).sum(dim=-1) / mask_sum pos_logp = (logp_per_token * pos).mean().item() neg_logp = (logp_per_token * neg).mean().item() total_pos_logp += pos_logp