[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-02-20 04:58:43 +00:00
parent ccdd5a1ca6
commit 60fb6cae11
11 changed files with 221 additions and 136 deletions


@@ -63,9 +63,10 @@ def compute_distillation_loss(
                 continue
             ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
-            teacher_lps = torch.tensor(
-                pos_lps, device=logits.device, dtype=logits.dtype
-            ) / temperature
+            teacher_lps = (
+                torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
+                / temperature
+            )
             student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
             student_subset = student_log_probs[ids_tensor]
@@ -75,7 +76,9 @@ def compute_distillation_loss(
                 token_loss = -(teacher_probs * student_subset).sum()
             else:
                 teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
-                token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()
+                token_loss = (
+                    teacher_probs * (teacher_log_probs - student_subset)
+                ).sum()
             total = total + token_loss
             count = count + 1.0
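The two hunks above are pure re-wraps of the per-token distillation loss. For orientation, here is a minimal self-contained sketch of what that loss computes at one (batch, time) position; the renormalization into `teacher_probs` and the `loss_type` branch names are assumptions inferred from the visible context, not copied from the file:

```python
import torch
import torch.nn.functional as F


def token_distill_loss(logits_bt, pos_ids, pos_lps, temperature=1.0, loss_type="kl"):
    """Distill one position: logits_bt is the student's [vocab] logit vector,
    pos_ids/pos_lps are the teacher's top-k token ids and logprobs there."""
    ids_tensor = torch.tensor(pos_ids, dtype=torch.long)
    teacher_lps = torch.tensor(pos_lps, dtype=logits_bt.dtype) / temperature
    student_log_probs = F.log_softmax(logits_bt / temperature, dim=-1)
    student_subset = student_log_probs[ids_tensor]
    # Renormalize the teacher over its top-k subset (assumed; the diff's
    # teacher_probs is computed in an elided line).
    teacher_probs = F.softmax(teacher_lps, dim=-1)
    if loss_type == "cross_entropy":  # branch name is hypothetical
        return -(teacher_probs * student_subset).sum()
    # Forward KL restricted to the teacher's top-k tokens, as in the diff.
    teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
    return (teacher_probs * (teacher_log_probs - student_subset)).sum()


loss = token_distill_loss(torch.randn(32_000), [11, 42, 7], [-0.2, -1.9, -3.1])
```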
@@ -323,7 +326,11 @@ def compute_grpo_loss(
     interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
     distill_loss_val = 0.0
-    if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None:
+    if (
+        distillation_enabled
+        and distill_token_ids is not None
+        and distill_logprobs is not None
+    ):
         distill_loss = compute_distillation_loss(
             logits=scaled_logits,
             mask=mask,
@@ -332,7 +339,10 @@ def compute_grpo_loss(
             temperature=distillation_temperature,
             loss_type=distillation_loss_type,
         )
-        total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps
+        total_loss = (
+            total_loss
+            + (distillation_coef * distill_loss) / gradient_accumulation_steps
+        )
         distill_loss_val = distill_loss.item()
     metrics = {
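The hunk above only moves the `gradient_accumulation_steps` division onto its own line, but the division itself is worth a note: `backward()` gradients sum across micro-batches, so each micro-batch loss is scaled by 1/N to keep the accumulated gradient equivalent to one full-batch step. A toy sketch under that assumption (`model`, `optimizer`, `micro_batches`, and `loss_fn` are illustrative names, not from this repo):

```python
def accumulate_step(model, optimizer, micro_batches, loss_fn, distillation_coef=0.1):
    n = len(micro_batches)
    optimizer.zero_grad()
    for batch in micro_batches:
        main_loss, distill_loss = loss_fn(model, batch)
        # Scale by 1/n so the n summed backward() calls match one full batch.
        total = (main_loss + distillation_coef * distill_loss) / n
        total.backward()
    optimizer.step()  # single update from the accumulated gradients
```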
@@ -437,9 +447,13 @@ def run_training_step(
         inf_logprobs = inference_logprob_batches[batch_idx]
         distill_ids = None
         distill_lps = None
-        if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches):
+        if distill_token_id_batches is not None and batch_idx < len(
+            distill_token_id_batches
+        ):
             distill_ids = distill_token_id_batches[batch_idx]
-        if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches):
+        if distill_logprob_batches is not None and batch_idx < len(
+            distill_logprob_batches
+        ):
             distill_lps = distill_logprob_batches[batch_idx]
         loss, metrics = compute_grpo_loss(
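The last hunk re-wraps two identical bounds checks: the distillation batch lists are optional and may be shorter than the main batch list, so each lookup guards against both `None` and an out-of-range index. If one wanted to factor out the repetition, a hypothetical helper might look like this (the `pick` name is invented; the real code inlines the checks as shown above):

```python
def pick(batches, batch_idx):
    """Return batches[batch_idx] if the list exists and is long enough, else None."""
    if batches is not None and batch_idx < len(batches):
        return batches[batch_idx]
    return None


print(pick(None, 0))        # None: distillation disabled
print(pick([[11, 42]], 0))  # [11, 42]
print(pick([[11, 42]], 3))  # None: index out of range
```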