Mirror of https://github.com/NousResearch/atropos.git
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent d8857eb69f
commit d1b0dee8f7
7 changed files with 53 additions and 19 deletions
@@ -311,7 +311,10 @@ def compute_distillation_loss(
     if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0

-    if distill_token_ids.shape[:2] != labels.shape or distill_logprobs.shape != distill_token_ids.shape:
+    if (
+        distill_token_ids.shape[:2] != labels.shape
+        or distill_logprobs.shape != distill_token_ids.shape
+    ):
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0

     temp = max(1e-6, float(temperature))
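For context on the guards being rewrapped above: `labels` is `[batch, seq_len]` while `distill_token_ids` and `distill_logprobs` carry a per-position top-k, i.e. `[batch, seq_len, k]`. A minimal sketch of that shape contract, with all sizes and the `truncated` tensor invented for illustration (not taken from the repo):

    import torch

    # Hypothetical sizes for illustration only.
    batch, seq_len, k = 2, 5, 8
    labels = torch.zeros(batch, seq_len, dtype=torch.long)
    distill_token_ids = torch.zeros(batch, seq_len, k, dtype=torch.long)
    distill_logprobs = torch.zeros(batch, seq_len, k)

    # Both early-return guards pass for aligned tensors:
    assert distill_token_ids.dim() == 3 and distill_logprobs.dim() == 3
    assert distill_token_ids.shape[:2] == labels.shape
    assert distill_logprobs.shape == distill_token_ids.shape

    # Any mismatch (e.g. a top-k truncated to 4 slots) would instead hit
    # the zero-loss early return in the function above.
    truncated = torch.zeros(batch, seq_len, 4)
    assert truncated.shape != distill_token_ids.shape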
@@ -330,7 +333,9 @@ def compute_distillation_loss(
     # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
     scaled_logits = logits / temp
     log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
-    student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    student_logp_topk = (
+        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    )

     masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
     teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
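The gather-minus-logsumexp line being rewrapped in this hunk is just `log_softmax` restricted to the teacher's top-k ids, so nothing larger than `[batch, seq_len, k]` is materialized after the reduction. A self-contained sketch checking that identity, plus the teacher-side masking from the same hunk; sizes and the `valid_ids` mask are made up for the example:

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes for illustration only.
    batch, seq_len, vocab, k = 2, 3, 10, 4
    logits = torch.randn(batch, seq_len, vocab)
    gather_ids = torch.randint(0, vocab, (batch, seq_len, k))
    temp = 1.0

    scaled_logits = logits / temp
    log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)
    student_logp_topk = (
        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
    )

    # Identical to a full log_softmax followed by the same gather:
    reference = F.log_softmax(scaled_logits, dim=-1).gather(-1, gather_ids)
    assert torch.allclose(student_logp_topk, reference, atol=1e-6)

    # Teacher side: invalid slots are filled with -1e9 so softmax sends
    # their probability to (near-)zero. Pretend the last top-k slot is padding:
    distill_logprobs = torch.randn(batch, seq_len, k)
    valid_ids = torch.ones(batch, seq_len, k, dtype=torch.bool)
    valid_ids[..., -1] = False
    masked = distill_logprobs.masked_fill(~valid_ids, -1e9)
    teacher_probs = F.softmax(masked / temp, dim=-1)
    assert teacher_probs[~valid_ids].max() < 1e-6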