mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
logprob alignment
This commit is contained in:
parent
871f846b10
commit
24b8ab8574
2 changed files with 26 additions and 27 deletions
|
|
@ -120,35 +120,31 @@ def pad_data_to_good_offset(
|
|||
advantages.append(item["scores"][i])
|
||||
|
||||
# Extract and pad inference logprobs to match labels shape
|
||||
# Inference logprobs are ONLY for generated tokens (where labels != -100)
|
||||
# We need to create a padded array that aligns position-by-position
|
||||
# IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
|
||||
# - 1.0 for prompt tokens (masked positions)
|
||||
# - actual negative logprobs for generated tokens
|
||||
# We just need to pad to match the sequence length, no realignment needed!
|
||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||
if i < len(item["inference_logprobs"]):
|
||||
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
has_any_logprobs = True
|
||||
|
||||
# Create padded logprobs array matching label_item shape
|
||||
# Fill with 0.0 (will be masked out during loss computation)
|
||||
padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
|
||||
# Create padded logprobs array matching token_setup_len
|
||||
# Fill with 1.0 (the masked token placeholder value) for padding
|
||||
padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)
|
||||
|
||||
# The inference logprobs correspond to generated tokens
|
||||
# Find positions where labels != -100 (generated positions)
|
||||
mask_arr = np.array(item["masks"][i])
|
||||
generated_positions = np.where(mask_arr != -100)[0]
|
||||
|
||||
# Fill in inference logprobs at generated positions
|
||||
n_to_fill = min(len(raw_logprobs), len(generated_positions))
|
||||
if n_to_fill > 0:
|
||||
padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
|
||||
# Copy raw_logprobs directly - they're already aligned with tokens
|
||||
n_to_copy = min(len(raw_logprobs), token_setup_len)
|
||||
padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]
|
||||
|
||||
# Shift by 1 to match causal label shift
|
||||
inference_logprobs_padded.append(padded_logprobs[1:])
|
||||
else:
|
||||
# No logprobs for this sample, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
# No logprobs for this sample, use 1.0 (masked placeholder)
|
||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
||||
elif extract_inference_logprobs:
|
||||
# No inference_logprobs in item, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
# No inference_logprobs in item, use 1.0 (masked placeholder)
|
||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue