testing set up

This commit is contained in:
Jai Suphavadeeprasit 2026-03-06 14:49:32 -05:00
parent f44eb810bf
commit 530fed2877
8 changed files with 599 additions and 2 deletions

View file

@ -29,6 +29,8 @@ def pad_data_to_good_offset(
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k]
Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k]
]:
"""
Pad and batch data from the Atropos API.
@ -45,7 +47,8 @@ def pad_data_to_good_offset(
extract_inference_logprobs: Whether to extract inference logprobs
Returns:
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches,
inference_logprob_batches, distill_token_id_batches, distill_logprob_batches)
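inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
distill_token_id_batches and distill_logprob_batches are None if no item in data carries usable distill_token_ids/distill_logprobs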
Note:
@ -73,6 +76,10 @@ def pad_data_to_good_offset(
temperatures = []
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
has_any_logprobs = False
distill_token_ids_padded: List[np.ndarray] = []
distill_logprobs_padded: List[np.ndarray] = []
has_any_distill = False
max_distill_k = 1
for item in data["batch"]:
# Normalize advantage scores
@ -153,6 +160,77 @@ def pad_data_to_good_offset(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
# Extract teacher distillation top-k arrays if available.
# Expected shape in incoming payload: [sequence][position][k].
if "distill_token_ids" in item and "distill_logprobs" in item:
seq_token_ids = item["distill_token_ids"]
seq_logprobs = item["distill_logprobs"]
if (
isinstance(seq_token_ids, list)
and isinstance(seq_logprobs, list)
and i < len(seq_token_ids)
and i < len(seq_logprobs)
and seq_token_ids[i] is not None
and seq_logprobs[i] is not None
):
per_pos_token_ids = seq_token_ids[i]
per_pos_logprobs = seq_logprobs[i]
if (
isinstance(per_pos_token_ids, list)
and isinstance(per_pos_logprobs, list)
and len(per_pos_token_ids) == len(per_pos_logprobs)
):
local_k = 1
for row_ids in per_pos_token_ids:
if isinstance(row_ids, list):
local_k = max(local_k, len(row_ids))
max_distill_k = max(max_distill_k, local_k)
has_any_distill = True
rows = max(0, token_setup_len - 1)
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
logprob_mat = np.full(
(rows, local_k), -1e9, dtype=np.float32
)
# Shift by one to align with causal labels like inference_logprobs.
copy_positions = min(
len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
)
for pos in range(1, copy_positions):
src_ids = per_pos_token_ids[pos]
src_lps = per_pos_logprobs[pos]
if not isinstance(src_ids, list) or not isinstance(src_lps, list):
continue
topk = min(local_k, len(src_ids), len(src_lps))
if topk <= 0:
continue
token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
logprob_mat[pos - 1, :topk] = np.array(
src_lps[:topk], dtype=np.float32
)
distill_token_ids_padded.append(token_mat)
distill_logprobs_padded.append(logprob_mat)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(
np.full((rows, 1), -1, dtype=np.int64)
)
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
if (
@ -178,6 +256,8 @@ def pad_data_to_good_offset(
advantage_batches = []
temperature_batches = []
inference_logprob_batches = []
distill_token_id_batches = []
distill_logprob_batches = []
for start in range(0, len(input_ids), batch_size):
end = min(start + batch_size, len(input_ids))
@ -199,12 +279,42 @@ def pad_data_to_good_offset(
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
)
if distill_token_ids_padded and distill_logprobs_padded:
seq_slice_ids = distill_token_ids_padded[start:end]
seq_slice_lps = distill_logprobs_padded[start:end]
normalized_ids = []
normalized_lps = []
for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps):
if ids_mat.shape[1] < max_distill_k:
pad_cols = max_distill_k - ids_mat.shape[1]
ids_mat = np.pad(
ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1
)
lps_mat = np.pad(
lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9
)
normalized_ids.append(ids_mat)
normalized_lps.append(lps_mat)
distill_token_id_batches.append(
torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long)
)
distill_logprob_batches.append(
torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32)
)
# Return inference logprob batches if we have any real logprobs
final_logprob_batches = (
inference_logprob_batches
if (has_any_logprobs and inference_logprob_batches)
else None
)
final_distill_token_id_batches = (
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
)
final_distill_logprob_batches = (
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
)
return (
token_batches,
@ -212,6 +322,8 @@ def pad_data_to_good_offset(
advantage_batches,
temperature_batches,
final_logprob_batches,
final_distill_token_id_batches,
final_distill_logprob_batches,
)
@ -228,6 +340,8 @@ def get_data(
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
Optional[List[torch.Tensor]], # distill_token_id_batches
Optional[List[torch.Tensor]], # distill_logprob_batches
]
],
None, # Legacy return (no longer used)
@ -299,6 +413,8 @@ def get_data(
adv_batches,
temp_batches,
inf_logprob_batches,
distill_token_id_batches,
distill_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
# Include inference logprob and distillation (token id / logprob) batches in the tuple
@ -309,6 +425,8 @@ def get_data(
adv_batches,
temp_batches,
inf_logprob_batches,
distill_token_id_batches,
distill_logprob_batches,
)
)