mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
change OPD style
This commit is contained in:
parent
33f5696171
commit
527433b5bc
10 changed files with 452 additions and 90 deletions
|
|
@ -7,6 +7,10 @@ and advantage normalization.
|
|||
Also extracts inference logprobs for proper GRPO loss computation:
|
||||
- Inference logprobs serve as π_old (the old sampling policy) for importance sampling
|
||||
- They are batched and padded to align token-by-token with training labels
|
||||
|
||||
Also supports optional on-policy distillation arrays:
|
||||
- distill_token_ids[seq][pos][top_k]
|
||||
- distill_logprobs[seq][pos][top_k]
|
||||
"""
|
||||
|
||||
import math
|
||||
|
|
@ -29,6 +33,8 @@ def pad_data_to_good_offset(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
|
||||
Optional[List[list]], # distill_token_id_batches (nested ragged arrays)
|
||||
Optional[List[list]], # distill_logprob_batches (nested ragged arrays)
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@ -45,7 +51,15 @@ def pad_data_to_good_offset(
|
|||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
Tuple of (
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
inference_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
)
|
||||
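inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data; distill_token_id_batches and distill_logprob_batches are None if no sequence in data provides distillation arrays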
|
||||
|
||||
Note:
|
||||
|
|
@ -73,6 +87,9 @@ def pad_data_to_good_offset(
|
|||
temperatures = []
|
||||
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
|
||||
has_any_logprobs = False
|
||||
distill_token_ids_padded: List[list] = []
|
||||
distill_logprobs_padded: List[list] = []
|
||||
has_any_distill = False
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@ -153,6 +170,36 @@ def pad_data_to_good_offset(
|
|||
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract optional distillation arrays.
|
||||
# Format:
|
||||
# distill_token_ids[seq][pos][top_k], distill_logprobs[seq][pos][top_k]
|
||||
seq_token_ids = None
|
||||
seq_logprobs = None
|
||||
if (
|
||||
isinstance(item.get("distill_token_ids"), list)
|
||||
and isinstance(item.get("distill_logprobs"), list)
|
||||
and i < len(item["distill_token_ids"])
|
||||
and i < len(item["distill_logprobs"])
|
||||
):
|
||||
seq_token_ids = item["distill_token_ids"][i]
|
||||
seq_logprobs = item["distill_logprobs"][i]
|
||||
|
||||
if seq_token_ids is not None and seq_logprobs is not None:
|
||||
has_any_distill = True
|
||||
seq_target_len = token_setup_len - 1
|
||||
padded_ids = list(seq_token_ids[:seq_target_len]) + [
|
||||
[] for _ in range(max(0, seq_target_len - len(seq_token_ids)))
|
||||
]
|
||||
padded_lps = list(seq_logprobs[:seq_target_len]) + [
|
||||
[] for _ in range(max(0, seq_target_len - len(seq_logprobs)))
|
||||
]
|
||||
distill_token_ids_padded.append(padded_ids)
|
||||
distill_logprobs_padded.append(padded_lps)
|
||||
else:
|
||||
seq_target_len = token_setup_len - 1
|
||||
distill_token_ids_padded.append([[] for _ in range(seq_target_len)])
|
||||
distill_logprobs_padded.append([[] for _ in range(seq_target_len)])
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
if (
|
||||
|
|
@ -178,6 +225,8 @@ def pad_data_to_good_offset(
|
|||
advantage_batches = []
|
||||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
distill_token_id_batches = []
|
||||
distill_logprob_batches = []
|
||||
|
||||
for start in range(0, len(input_ids), batch_size):
|
||||
end = min(start + batch_size, len(input_ids))
|
||||
|
|
@ -198,6 +247,9 @@ def pad_data_to_good_offset(
|
|||
inference_logprob_batches.append(
|
||||
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
|
||||
)
|
||||
if distill_token_ids_padded:
|
||||
distill_token_id_batches.append(distill_token_ids_padded[start:end])
|
||||
distill_logprob_batches.append(distill_logprobs_padded[start:end])
|
||||
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = (
|
||||
|
|
@ -205,6 +257,12 @@ def pad_data_to_good_offset(
|
|||
if (has_any_logprobs and inference_logprob_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_token_id_batches = (
|
||||
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
|
||||
)
|
||||
final_distill_logprob_batches = (
|
||||
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
|
||||
)
|
||||
|
||||
return (
|
||||
token_batches,
|
||||
|
|
@ -212,6 +270,8 @@ def pad_data_to_good_offset(
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
final_logprob_batches,
|
||||
final_distill_token_id_batches,
|
||||
final_distill_logprob_batches,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -228,6 +288,8 @@ def get_data(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
Optional[List[list]], # distill_token_id_batches
|
||||
Optional[List[list]], # distill_logprob_batches
|
||||
]
|
||||
],
|
||||
None, # Legacy return (no longer used)
|
||||
|
|
@ -247,7 +309,15 @@ def get_data(
|
|||
Returns:
|
||||
Tuple of (batches, None)
|
||||
- batches: List of processed batch tuples, each containing:
|
||||
(token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
(
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
inference_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
)
|
||||
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
|
||||
"""
|
||||
batches = []
|
||||
|
|
@ -299,6 +369,8 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
# Include inference logprob and distillation batches in the tuple
|
||||
|
|
@ -309,6 +381,8 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue