mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
logprob alignment
This commit is contained in:
parent
871f846b10
commit
24b8ab8574
2 changed files with 26 additions and 27 deletions
|
|
@ -215,16 +215,19 @@ def compute_grpo_loss(
|
|||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
|
||||
# DEBUG: Check if inference logprobs look valid
|
||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||
with torch.no_grad():
|
||||
ref_nonzero = (ref_logprobs != 0).float()
|
||||
ref_nonzero_frac = (ref_nonzero * mask).sum() / mask.sum()
|
||||
ref_mean = (ref_logprobs * mask).sum() / mask.sum()
|
||||
train_mean = (logp_per_token * mask).sum() / mask.sum()
|
||||
if ref_nonzero_frac < 0.5:
|
||||
print(f" [WARNING] Only {ref_nonzero_frac*100:.1f}% of inference logprobs are non-zero!")
|
||||
print(f" [WARNING] This suggests inference_logprobs field may be missing from data")
|
||||
if abs(ref_mean - train_mean) > 1.0:
|
||||
print(f" [DEBUG] Large logprob gap: ref_mean={ref_mean:.3f}, train_mean={train_mean:.3f}")
|
||||
# Only look at generated positions (where mask == 1)
|
||||
ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
|
||||
train_at_generated = (logp_per_token * mask).sum() / mask.sum()
|
||||
|
||||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||
if ref_at_generated > 0.5:
|
||||
print(f" [WARNING] ref_logprobs at generated positions avg {ref_at_generated:.3f} (should be negative!)")
|
||||
print(f" [WARNING] This suggests inference_logprobs alignment is still wrong")
|
||||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||
print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
|
||||
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue