"""
|
|
Data processing utilities for GRPO trainer.
|
|
|
|
Handles data retrieval from Atropos API, padding, batching,
|
|
and advantage normalization.
|
|
|
|
Also extracts inference logprobs for proper GRPO loss computation:
|
|
- Inference logprobs are used in importance-ratio computation
|
|
- They are batched and padded to align token-by-token with training labels
|
|
"""

import math
import time
from typing import List, Optional, Tuple

import numpy as np
import torch

from .api import get_batch
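
# Summary of the payload consumed below (derived from the code in this module):
# get_batch() returns {"batch": [...]} where each item carries "tokens", "masks",
# "scores", and "overrides", plus optional "inference_logprobs",
# "distill_token_ids"/"distill_logprobs", "generation_params", and "group_overrides".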


def pad_data_to_good_offset(
    data: dict,
    batch_size: int,
    extract_inference_logprobs: bool = True,
) -> Tuple[
    List[torch.Tensor],  # token_batches
    List[torch.Tensor],  # label_batches
    List[torch.Tensor],  # advantage_batches
    List[torch.Tensor],  # temperature_batches
    Optional[List[torch.Tensor]],  # inference_logprob_batches (aligned with labels)
    Optional[List[torch.Tensor]],  # distill_token_id_batches [batch, seq, k]
    Optional[List[torch.Tensor]],  # distill_logprob_batches [batch, seq, k]
]:
    """
    Pad and batch data from the Atropos API.

    Processes raw batch data into properly padded tensors suitable for training:
    - Pads token sequences to the nearest multiple of 64
    - Normalizes advantage scores
    - Extracts temperature values
    - Extracts and pads inference logprobs for proper GRPO loss computation

    Args:
        data: Raw batch data from the Atropos API
        batch_size: Size of each training batch
        extract_inference_logprobs: Whether to extract inference logprobs

    Returns:
        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches,
        inference_logprob_batches, distill_token_id_batches, distill_logprob_batches).
        inference_logprob_batches is None if extract_inference_logprobs=False or the
        data contains no logprobs.

    Note:
        inference_logprob_batches carry the 1.0 placeholder value at positions where
        labels == -100, which keeps them aligned token-by-token with the labels
        during GRPO loss computation.
    """
    max_token_len = max(
        [max([len(x) for x in item["tokens"]]) for item in data["batch"]]
    )

    # Pad to nearest multiple of 64 for GPU efficiency
    good_multiple = 64
    if (max_token_len - 1) % good_multiple != 0:
        max_token_len = math.ceil((max_token_len - 1) / good_multiple) * good_multiple
        token_setup_len = max_token_len + 1  # +1 for causal shift
    else:
        token_setup_len = max_token_len
        max_token_len = max_token_len - 1  # -1 for causal shift
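
    # Worked example (illustrative): max_token_len = 1000 gives
    # ceil(999 / 64) * 64 = 1024, so token_setup_len = 1025; after the causal
    # shift below, inputs and labels are both 1024 tokens long.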

    # Process all items
    input_ids = []
    labels = []
    advantages = []
    lengths = []
    temperatures = []
    inference_logprobs_padded: List[np.ndarray] = []  # Padded to match labels shape
    has_any_logprobs = False
    distill_token_ids_padded: List[np.ndarray] = []
    distill_logprobs_padded: List[np.ndarray] = []
    has_any_distill = False
    max_distill_k = 1

    for item in data["batch"]:
        # Normalize advantage scores
        scores = np.array(item["scores"])
        if len(scores) > 1:
            scores = scores - scores.mean()
            scores = scores / max(scores.std(), 1e-8)
        item["scores"] = scores
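        # Illustrative example: group scores [1.0, 0.0, 0.0, 1.0] have mean 0.5
        # and std 0.5, so the normalized advantages become [1.0, -1.0, -1.0, 1.0].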

        # Handle score overrides
        if item["overrides"] is not None:
            for i in range(len(item["overrides"])):
                if item["overrides"][i].get("set_advantage_to_zero", False):
                    item["scores"][i] = 0

        # Process each sample in the item
        for i in range(len(item["tokens"])):
            seq_len = len(item["tokens"][i])
            lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)

            # Create labels with padding (-100 for masked positions)
            label_item = np.concatenate(
                [
                    np.array(item["masks"][i]),
                    np.full(
                        max(0, token_setup_len - seq_len),
                        -100,
                        dtype=np.int32,
                    ),
                ]
            )

            # Pad tokens
            item["tokens"][i] = np.concatenate(
                [
                    np.array(item["tokens"][i]),
                    np.zeros(
                        max(0, token_setup_len - seq_len),
                        dtype=np.int32,
                    ),
                ]
            )

            input_ids.append(item["tokens"][i][:-1])  # Remove last for causal
            labels.append(label_item[1:])  # Shift by 1 for causal
            advantages.append(item["scores"][i])

            # Extract and pad inference logprobs to match labels shape
            # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
            # - 1.0 for prompt tokens (masked positions)
            # - actual negative logprobs for generated tokens
            # We just need to pad to match the sequence length
            if extract_inference_logprobs and "inference_logprobs" in item:
                if i < len(item["inference_logprobs"]):
                    raw_logprobs = np.array(
                        item["inference_logprobs"][i], dtype=np.float32
                    )
                    has_any_logprobs = True

                    # Create padded logprobs array matching token_setup_len
                    # Fill with 1.0 (the masked token placeholder value) for padding
                    padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32)

                    # Copy raw_logprobs directly - they're already aligned with tokens
                    n_to_copy = min(len(raw_logprobs), token_setup_len)
                    padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy]

                    # Shift by 1 to match causal label shift
                    inference_logprobs_padded.append(padded_logprobs[1:])
                else:
                    # No logprobs for this sample, use 1.0
                    inference_logprobs_padded.append(
                        np.full(token_setup_len - 1, 1.0, dtype=np.float32)
                    )
            elif extract_inference_logprobs:
                # No inference_logprobs in item, use 1.0
                inference_logprobs_padded.append(
                    np.full(token_setup_len - 1, 1.0, dtype=np.float32)
                )

            # Extract teacher distillation top-k arrays if available.
            # Expected shape in incoming payload: [sequence][position][k].
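            # Illustrative (hypothetical values): with k = 2, item["distill_token_ids"][i]
            # could look like [[42, 7], [13, 99], ...] and item["distill_logprobs"][i] like
            # [[-0.1, -2.3], [-0.5, -1.7], ...], one inner list per token position.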
            if "distill_token_ids" in item and "distill_logprobs" in item:
                seq_token_ids = item["distill_token_ids"]
                seq_logprobs = item["distill_logprobs"]
                if (
                    isinstance(seq_token_ids, list)
                    and isinstance(seq_logprobs, list)
                    and i < len(seq_token_ids)
                    and i < len(seq_logprobs)
                    and seq_token_ids[i] is not None
                    and seq_logprobs[i] is not None
                ):
                    per_pos_token_ids = seq_token_ids[i]
                    per_pos_logprobs = seq_logprobs[i]
                    if (
                        isinstance(per_pos_token_ids, list)
                        and isinstance(per_pos_logprobs, list)
                        and len(per_pos_token_ids) == len(per_pos_logprobs)
                    ):
                        local_k = 1
                        for row_ids in per_pos_token_ids:
                            if isinstance(row_ids, list):
                                local_k = max(local_k, len(row_ids))
                        max_distill_k = max(max_distill_k, local_k)
                        has_any_distill = True

                        rows = max(0, token_setup_len - 1)
                        token_mat = np.full((rows, local_k), -1, dtype=np.int64)
                        logprob_mat = np.full(
                            (rows, local_k), -1e9, dtype=np.float32
                        )

                        # Shift by one to align with causal labels like inference_logprobs.
                        copy_positions = min(
                            len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
                        )
                        for pos in range(1, copy_positions):
                            src_ids = per_pos_token_ids[pos]
                            src_lps = per_pos_logprobs[pos]
                            if not isinstance(src_ids, list) or not isinstance(src_lps, list):
                                continue
                            topk = min(local_k, len(src_ids), len(src_lps))
                            if topk <= 0:
                                continue
                            token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
                            logprob_mat[pos - 1, :topk] = np.array(
                                src_lps[:topk], dtype=np.float32
                            )

                        distill_token_ids_padded.append(token_mat)
                        distill_logprobs_padded.append(logprob_mat)
                    else:
                        rows = max(0, token_setup_len - 1)
                        distill_token_ids_padded.append(
                            np.full((rows, 1), -1, dtype=np.int64)
                        )
                        distill_logprobs_padded.append(
                            np.full((rows, 1), -1e9, dtype=np.float32)
                        )
                else:
                    rows = max(0, token_setup_len - 1)
                    distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
                    distill_logprobs_padded.append(
                        np.full((rows, 1), -1e9, dtype=np.float32)
                    )
            else:
                rows = max(0, token_setup_len - 1)
                distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
                distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))

            # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
            t = 1.0
            if (
                item.get("overrides")
                and i < len(item["overrides"])
                and isinstance(item["overrides"][i], dict)
                and ("temperature" in item["overrides"][i])
            ):
                t = float(item["overrides"][i]["temperature"])
            elif item.get("generation_params") and (
                "temperature" in item["generation_params"]
            ):
                t = float(item["generation_params"]["temperature"])
            elif item.get("group_overrides") and (
                "temperature" in item["group_overrides"]
            ):
                t = float(item["group_overrides"]["temperature"])
            temperatures.append(t)

    # Batch the data
    token_batches = []
    label_batches = []
    advantage_batches = []
    temperature_batches = []
    inference_logprob_batches = []
    distill_token_id_batches = []
    distill_logprob_batches = []

    for start in range(0, len(input_ids), batch_size):
        end = min(start + batch_size, len(input_ids))

        token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
        label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
        advantage_batches.append(
            torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
        )
        temperature_batches.append(
            torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
                -1, 1, 1
            )
        )
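
        # Shapes: advantages are [batch, 1] and temperatures [batch, 1, 1], so
        # (presumably) they can broadcast over per-token values and logits downstream.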

        # Batch inference logprobs (same shape as labels)
        if extract_inference_logprobs and inference_logprobs_padded:
            inference_logprob_batches.append(
                torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
            )

        if distill_token_ids_padded and distill_logprobs_padded:
            seq_slice_ids = distill_token_ids_padded[start:end]
            seq_slice_lps = distill_logprobs_padded[start:end]
            normalized_ids = []
            normalized_lps = []
            for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps):
                if ids_mat.shape[1] < max_distill_k:
                    pad_cols = max_distill_k - ids_mat.shape[1]
                    ids_mat = np.pad(
                        ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1
                    )
                    lps_mat = np.pad(
                        lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9
                    )
                normalized_ids.append(ids_mat)
                normalized_lps.append(lps_mat)

            distill_token_id_batches.append(
                torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long)
            )
            distill_logprob_batches.append(
                torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32)
            )

    # Return inference logprob batches if we have any real logprobs
    final_logprob_batches = (
        inference_logprob_batches
        if (has_any_logprobs and inference_logprob_batches)
        else None
    )
    final_distill_token_id_batches = (
        distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
    )
    final_distill_logprob_batches = (
        distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
    )

    return (
        token_batches,
        label_batches,
        advantage_batches,
        temperature_batches,
        final_logprob_batches,
        final_distill_token_id_batches,
        final_distill_logprob_batches,
    )
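

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer code): one way the
# batches above could feed a token-level GRPO importance ratio. The names
# `_example_importance_ratio` and `policy_logprobs` are hypothetical; the
# actual loss lives elsewhere in the trainer.
# ---------------------------------------------------------------------------
def _example_importance_ratio(
    policy_logprobs: torch.Tensor,  # [batch, seq] logprobs under the current policy
    inference_logprobs: torch.Tensor,  # [batch, seq] from inference_logprob_batches
    labels: torch.Tensor,  # [batch, seq] with -100 at masked/padded positions
) -> torch.Tensor:
    """Return exp(policy - inference) per token, zeroed where labels == -100."""
    valid = labels != -100
    # Masked positions carry the 1.0 placeholder, so only keep the ratio where valid.
    ratio = torch.exp(policy_logprobs - inference_logprobs)
    return torch.where(valid, ratio, torch.zeros_like(ratio))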


def get_data(
    batch_size: int,
    seq_len: int,
    atropos_url: str = "http://localhost:8000",
    extract_inference_logprobs: bool = True,
) -> Tuple[
    List[
        Tuple[
            List[torch.Tensor],  # token_batches
            List[torch.Tensor],  # label_batches
            List[torch.Tensor],  # advantage_batches
            List[torch.Tensor],  # temperature_batches
            Optional[List[torch.Tensor]],  # inference_logprob_batches
            Optional[List[torch.Tensor]],  # distill_token_id_batches
            Optional[List[torch.Tensor]],  # distill_logprob_batches
        ]
    ],
    None,  # Legacy return (no longer used)
]:
    """
    Fetch and process training data from the Atropos API.

    Continuously polls the API until data is available, then processes
    all available batches.

    Args:
        batch_size: Size of each training batch
        seq_len: Maximum sequence length (for reference, not used directly)
        atropos_url: URL of the Atropos API server
        extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss

    Returns:
        Tuple of (batches, None)
        - batches: List of processed batch tuples, each containing
          (token_batches, label_batches, advantage_batches, temperature_batches,
          inference_logprob_batches, distill_token_id_batches, distill_logprob_batches)
        - inference_logprob_batches are aligned with labels for proper GRPO loss computation
    """
    batches = []
    _logged_logprob_warning = False

    while True:
        data = get_batch(url=atropos_url)

        if data["batch"] is not None:
            # DEBUG: Check if inference_logprobs exists in the data
            if not _logged_logprob_warning:
                has_logprobs = any(
                    "inference_logprobs" in item for item in data["batch"]
                )
                if has_logprobs:
                    # Check if they're non-empty
                    sample_item = next(
                        (
                            item
                            for item in data["batch"]
                            if "inference_logprobs" in item
                        ),
                        None,
                    )
                    if sample_item and sample_item.get("inference_logprobs"):
                        sample_lp = (
                            sample_item["inference_logprobs"][0]
                            if sample_item["inference_logprobs"]
                            else []
                        )
                        print(
                            f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
                        )
                    else:
                        print(
                            " [Data] ⚠ inference_logprobs key exists but is empty!"
                        )
                else:
                    print(" [Data] ⚠ NO inference_logprobs in batch data!")
                    print(
                        f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
                    )
                _logged_logprob_warning = True

            # Process and accumulate batches (now includes batched inference logprobs)
            (
                token_batches,
                label_batches,
                adv_batches,
                temp_batches,
                inf_logprob_batches,
                distill_token_id_batches,
                distill_logprob_batches,
            ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)

            # Include inference logprob batches in the tuple
            batches.append(
                (
                    token_batches,
                    label_batches,
                    adv_batches,
                    temp_batches,
                    inf_logprob_batches,
                    distill_token_id_batches,
                    distill_logprob_batches,
                )
            )

        elif len(batches) > 0:
            # Return accumulated batches when no more data
            return batches, None
        else:
            # Wait for data
            time.sleep(1)
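

# Illustrative usage (hypothetical, not part of the original module): assumes an
# Atropos API server is reachable at http://localhost:8000 with data queued, and
# that this module is run within its package so the relative import resolves.
if __name__ == "__main__":
    processed, _ = get_data(batch_size=8, seq_len=4096)
    for token_bs, label_bs, adv_bs, temp_bs, inf_lp_bs, dist_id_bs, dist_lp_bs in processed:
        for tokens_t, labels_t, adv_t in zip(token_bs, label_bs, adv_bs):
            # Each padded batch: tokens/labels [batch, seq], advantages [batch, 1]
            print(tokens_t.shape, labels_t.shape, adv_t.shape)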