[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-03-13 15:14:05 +00:00
parent d8857eb69f
commit d1b0dee8f7
7 changed files with 53 additions and 19 deletions

@@ -311,7 +311,10 @@ def compute_distillation_loss(
     if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
-    if distill_token_ids.shape[:2] != labels.shape or distill_logprobs.shape != distill_token_ids.shape:
+    if (
+        distill_token_ids.shape[:2] != labels.shape
+        or distill_logprobs.shape != distill_token_ids.shape
+    ):
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
     temp = max(1e-6, float(temperature))
@@ -330,7 +333,9 @@ def compute_distillation_loss(
     # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
     scaled_logits = logits / temp
     log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1]
-    student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    student_logp_topk = (
+        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    )
     masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
     teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
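For context, a minimal self-contained sketch of the top-k distillation step these hunks touch: the student's temperature-scaled log-probabilities are gathered at the teacher's top-k token ids and weighted by the teacher's temperature-scaled probabilities. The diff does not show how gather_ids and valid_ids are built or how the loss is reduced, so the padding convention, reduction, and temp**2 scaling below are assumptions, not the file's actual implementation (which also appears to return a tuple rather than a single scalar).

    import torch
    import torch.nn.functional as F

    def topk_distill_sketch(logits, distill_token_ids, distill_logprobs, temperature=1.0):
        # logits: [batch, seq_len, vocab]; distill_token_ids / distill_logprobs: [batch, seq_len, k]
        temp = max(1e-6, float(temperature))
        # Assumption: padded top-k slots are marked with negative token ids.
        valid_ids = distill_token_ids >= 0
        gather_ids = distill_token_ids.clamp(min=0)  # safe indices for torch.gather

        scaled_logits = logits / temp
        log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
        student_logp_topk = (
            torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
        )  # [b, s, k]

        masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
        teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)  # [b, s, k]

        # Assumed reduction: teacher-weighted cross-entropy over the top-k slots,
        # averaged over all positions and rescaled by temp**2.
        per_token = -(teacher_probs * student_logp_topk).sum(dim=-1)  # [b, s]
        return per_token.mean() * temp**2

The temp**2 factor is the standard Hinton-style correction that keeps distillation gradients comparable across temperatures; whether this function applies it is not visible in the diff.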