diff --git a/example_trainer/data.py b/example_trainer/data.py
index e79fecb0..74e1ffe1 100644
--- a/example_trainer/data.py
+++ b/example_trainer/data.py
@@ -120,35 +120,31 @@ def pad_data_to_good_offset(
             advantages.append(item["scores"][i])

             # Extract and pad inference logprobs to match labels shape
-            # Inference logprobs are ONLY for generated tokens (where labels != -100)
-            # We need to create a padded array that aligns position-by-position
+            # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
+            #   - 1.0 for prompt tokens (masked positions)
+            #   - actual negative logprobs for generated tokens
+            # We just need to pad to match the sequence length, no realignment needed!
             if extract_inference_logprobs and "inference_logprobs" in item:
                 if i < len(item["inference_logprobs"]):
                     raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
                     has_any_logprobs = True

-                    # Create padded logprobs array matching label_item shape
-                    # Fill with 0.0 (will be masked out during loss computation)
-                    padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
+                    # Create padded logprobs array matching token_setup_len
+                    # Fill with 1.0 (the masked-token placeholder value) for padding
+                    padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)

-                    # The inference logprobs correspond to generated tokens
-                    # Find positions where labels != -100 (generated positions)
-                    mask_arr = np.array(item["masks"][i])
-                    generated_positions = np.where(mask_arr != -100)[0]
-
-                    # Fill in inference logprobs at generated positions
-                    n_to_fill = min(len(raw_logprobs), len(generated_positions))
-                    if n_to_fill > 0:
-                        padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
+                    # Copy raw_logprobs directly - they're already aligned with tokens
+                    n_to_copy = min(len(raw_logprobs), token_setup_len)
+                    padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]

                     # Shift by 1 to match causal label shift
                     inference_logprobs_padded.append(padded_logprobs[1:])
                 else:
-                    # No logprobs for this sample, use zeros
-                    inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
+                    # No logprobs for this sample, use 1.0 (masked placeholder)
+                    inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
             elif extract_inference_logprobs:
-                # No inference_logprobs in item, use zeros
-                inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
+                # No inference_logprobs in item, use 1.0 (masked placeholder)
+                inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))

             # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
             t = 1.0
diff --git a/example_trainer/training.py b/example_trainer/training.py
index f36e9b63..39853fad 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -215,16 +215,19 @@ def compute_grpo_loss(
    ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)

    # DEBUG: Check if inference logprobs look valid
+   # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
    with torch.no_grad():
-       ref_nonzero = (ref_logprobs != 0).float()
-       ref_nonzero_frac = (ref_nonzero * mask).sum() / mask.sum()
-       ref_mean = (ref_logprobs * mask).sum() / mask.sum()
-       train_mean = (logp_per_token * mask).sum() / mask.sum()
-       if ref_nonzero_frac < 0.5:
-           print(f" [WARNING] Only {ref_nonzero_frac*100:.1f}% of inference logprobs are non-zero!")
-           print(f" [WARNING] This suggests inference_logprobs field may be missing from data")
-       if abs(ref_mean - train_mean) > 1.0:
-           print(f" [DEBUG] Large logprob gap: ref_mean={ref_mean:.3f}, train_mean={train_mean:.3f}")
+       # Only look at generated positions (where mask == 1)
+       ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
+       train_at_generated = (logp_per_token * mask).sum() / mask.sum()
+
+       # Check that ref logprobs are negative (as they should be for generated tokens)
+       # If ref_at_generated is close to 1.0, the 1.0 placeholder is leaking into generated positions
+       if ref_at_generated > 0.5:
+           print(f" [WARNING] ref_logprobs at generated positions avg {ref_at_generated:.3f} (should be negative!)")
+           print(f" [WARNING] This suggests inference_logprobs alignment is still wrong")
+       elif abs(ref_at_generated - train_at_generated) > 2.0:
+           print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")

    # Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
    log_ratio = logp_per_token - ref_logprobs
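
For reference, a minimal standalone sketch of the convention both hunks rely on (NumPy only; the toy sequence, the token_setup_len value, and the mask layout are illustrative assumptions, not trainer code): the logprob array arrives already aligned with the tokens, so padding uses the 1.0 placeholder, the only transformation is the shift-by-one matching the causal label shift, and the masked mean at generated positions should come out clearly negative.

import numpy as np

token_setup_len = 8  # padded sequence length, as in pad_data_to_good_offset

# Toy sample: 3 prompt tokens then 3 generated tokens (values made up).
# Convention from the patch: 1.0 at prompt positions, real negative
# logprobs at generated positions, already aligned with the tokens.
raw_logprobs = np.array([1.0, 1.0, 1.0, -0.7, -1.2, -0.3], dtype=np.float32)

# Pad with the 1.0 placeholder, copy directly (no realignment), shift by 1.
padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)
n_to_copy = min(len(raw_logprobs), token_setup_len)
padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]
shifted = padded_logprobs[1:]  # matches the causal label shift

# Hypothetical mask over the shifted positions (1 = generated token);
# the trainer derives this from masks/labels, which is not shown here.
mask = np.array([0, 0, 1, 1, 1, 0, 0], dtype=np.float32)

# Training-side sanity check from the patch: the masked mean should be
# clearly negative; a value near 1.0 means placeholders leaked through.
ref_at_generated = (shifted * mask).sum() / mask.sum()
print(f"ref logprob at generated positions: {ref_at_generated:.3f}")  # -0.733
assert ref_at_generated < 0.5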