[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-03-13 15:14:05 +00:00
parent d8857eb69f
commit d1b0dee8f7
7 changed files with 53 additions and 19 deletions

View file

@ -189,23 +189,27 @@ def pad_data_to_good_offset(
rows = max(0, token_setup_len - 1)
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
logprob_mat = np.full(
(rows, local_k), -1e9, dtype=np.float32
)
logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32)
# Shift by one to align with causal labels like inference_logprobs.
copy_positions = min(
len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
len(per_pos_token_ids),
len(per_pos_logprobs),
token_setup_len,
)
for pos in range(1, copy_positions):
src_ids = per_pos_token_ids[pos]
src_lps = per_pos_logprobs[pos]
if not isinstance(src_ids, list) or not isinstance(src_lps, list):
if not isinstance(src_ids, list) or not isinstance(
src_lps, list
):
continue
topk = min(local_k, len(src_ids), len(src_lps))
if topk <= 0:
continue
token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
token_mat[pos - 1, :topk] = np.array(
src_ids[:topk], dtype=np.int64
)
logprob_mat[pos - 1, :topk] = np.array(
src_lps[:topk], dtype=np.float32
)
@ -222,14 +226,18 @@ def pad_data_to_good_offset(
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
distill_token_ids_padded.append(
np.full((rows, 1), -1, dtype=np.int64)
)
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
@ -310,10 +318,14 @@ def pad_data_to_good_offset(
else None
)
final_distill_token_id_batches = (
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
distill_token_id_batches
if (has_any_distill and distill_token_id_batches)
else None
)
final_distill_logprob_batches = (
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
distill_logprob_batches
if (has_any_distill and distill_logprob_batches)
else None
)
return (