mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d8857eb69f
commit
d1b0dee8f7
7 changed files with 53 additions and 19 deletions
|
|
@ -189,23 +189,27 @@ def pad_data_to_good_offset(
|
|||
|
||||
rows = max(0, token_setup_len - 1)
|
||||
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
|
||||
logprob_mat = np.full(
|
||||
(rows, local_k), -1e9, dtype=np.float32
|
||||
)
|
||||
logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32)
|
||||
|
||||
# Shift by one to align with causal labels like inference_logprobs.
|
||||
copy_positions = min(
|
||||
len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
|
||||
len(per_pos_token_ids),
|
||||
len(per_pos_logprobs),
|
||||
token_setup_len,
|
||||
)
|
||||
for pos in range(1, copy_positions):
|
||||
src_ids = per_pos_token_ids[pos]
|
||||
src_lps = per_pos_logprobs[pos]
|
||||
if not isinstance(src_ids, list) or not isinstance(src_lps, list):
|
||||
if not isinstance(src_ids, list) or not isinstance(
|
||||
src_lps, list
|
||||
):
|
||||
continue
|
||||
topk = min(local_k, len(src_ids), len(src_lps))
|
||||
if topk <= 0:
|
||||
continue
|
||||
token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
|
||||
token_mat[pos - 1, :topk] = np.array(
|
||||
src_ids[:topk], dtype=np.int64
|
||||
)
|
||||
logprob_mat[pos - 1, :topk] = np.array(
|
||||
src_lps[:topk], dtype=np.float32
|
||||
)
|
||||
|
|
@ -222,14 +226,18 @@ def pad_data_to_good_offset(
|
|||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
|
||||
distill_token_ids_padded.append(
|
||||
np.full((rows, 1), -1, dtype=np.int64)
|
||||
)
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
|
||||
distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
|
|
@ -310,10 +318,14 @@ def pad_data_to_good_offset(
|
|||
else None
|
||||
)
|
||||
final_distill_token_id_batches = (
|
||||
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
|
||||
distill_token_id_batches
|
||||
if (has_any_distill and distill_token_id_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_logprob_batches = (
|
||||
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
|
||||
distill_logprob_batches
|
||||
if (has_any_distill and distill_logprob_batches)
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue