logprob alignment

This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 16:27:09 -05:00
parent 871f846b10
commit 24b8ab8574
2 changed files with 26 additions and 27 deletions

View file

@ -120,35 +120,31 @@ def pad_data_to_good_offset(
advantages.append(item["scores"][i])
# Extract and pad inference logprobs to match labels shape
# Inference logprobs are ONLY for generated tokens (where labels != -100)
# We need to create a padded array that aligns position-by-position
# IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
# - 1.0 for prompt tokens (masked positions)
# - actual negative logprobs for generated tokens
# We just need to pad to match the sequence length, no realignment needed!
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
has_any_logprobs = True
# Create padded logprobs array matching label_item shape
# Fill with 0.0 (will be masked out during loss computation)
padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
# Create padded logprobs array matching token_setup_len
# Fill with 1.0 (the masked token placeholder value) for padding
padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)
# The inference logprobs correspond to generated tokens
# Find positions where labels != -100 (generated positions)
mask_arr = np.array(item["masks"][i])
generated_positions = np.where(mask_arr != -100)[0]
# Fill in inference logprobs at generated positions
n_to_fill = min(len(raw_logprobs), len(generated_positions))
if n_to_fill > 0:
padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
# Copy raw_logprobs directly - they're already aligned with tokens
n_to_copy = min(len(raw_logprobs), token_setup_len)
padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]
# Shift by 1 to match causal label shift
inference_logprobs_padded.append(padded_logprobs[1:])
else:
# No logprobs for this sample, use zeros
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
# No logprobs for this sample, use 1.0 (masked placeholder)
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
elif extract_inference_logprobs:
# No inference_logprobs in item, use zeros
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
# No inference_logprobs in item, use 1.0 (masked placeholder)
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0