Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent ccdd5a1ca6
commit 60fb6cae11

11 changed files with 221 additions and 136 deletions
@@ -63,9 +63,10 @@ def compute_distillation_loss(
                 continue

             ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
-            teacher_lps = torch.tensor(
-                pos_lps, device=logits.device, dtype=logits.dtype
-            ) / temperature
+            teacher_lps = (
+                torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
+                / temperature
+            )

             student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
             student_subset = student_log_probs[ids_tensor]
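
This hunk only re-wraps the teacher_lps expression, but it shows the per-position setup in compute_distillation_loss: build the temperature-scaled teacher log-probs for the teacher's candidate token ids and gather the student's log-probs at the same ids. A minimal standalone sketch of that step; pos_ids, pos_lps, logits and temperature are names from the hunk, while the toy shapes and values below are illustrative assumptions:

import torch
import torch.nn.functional as F

# Toy stand-ins; in the real code these come from the surrounding training loop.
logits = torch.randn(2, 8, 128)   # [batch, seq_len, vocab]
temperature = 1.0
b, t = 0, 3                       # one (batch, position) pair
pos_ids = [5, 17, 42]             # teacher candidate token ids at this position
pos_lps = [-0.3, -1.7, -2.9]      # teacher log-probs for those ids

ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
# Temperature-scale the teacher log-probs (the expression the hook re-parenthesized).
teacher_lps = (
    torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype) / temperature
)

# Student distribution at the same position, also temperature-scaled,
# then restricted to the teacher's candidate ids.
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
student_subset = student_log_probs[ids_tensor]
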
@@ -75,7 +76,9 @@ def compute_distillation_loss(
                 token_loss = -(teacher_probs * student_subset).sum()
             else:
                 teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
-                token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()
+                token_loss = (
+                    teacher_probs * (teacher_log_probs - student_subset)
+                ).sum()

             total = total + token_loss
             count = count + 1.0
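
The second hunk re-wraps the KL branch of the per-token loss. Together with the context lines it shows two modes: a cross-entropy form, -(teacher_probs * student_subset).sum(), and a forward-KL form over the teacher's candidate ids. A self-contained sketch of both branches; teacher_probs is assumed here to be the softmax of the temperature-scaled teacher log-probs, and the loss_type values are hypothetical, since neither appears in this hunk:

import torch
import torch.nn.functional as F

# Toy inputs standing in for one position's teacher/student restriction.
teacher_lps = torch.tensor([-0.3, -1.7, -2.9])      # temperature-scaled teacher log-probs
student_subset = torch.tensor([-0.5, -1.2, -3.4])   # student log-probs on the same ids

teacher_probs = torch.softmax(teacher_lps, dim=-1)  # assumed normalization, not shown here
loss_type = "kl"                                    # hypothetical stand-in for loss_type

total, count = torch.tensor(0.0), 0.0
if loss_type == "cross_entropy":
    token_loss = -(teacher_probs * student_subset).sum()
else:
    teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
    # Forward KL restricted to the teacher's candidate ids
    # (the line pre-commit split across three).
    token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()

total = total + token_loss
count = count + 1.0
mean_loss = total / count
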
@@ -323,7 +326,11 @@ def compute_grpo_loss(
     interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()

     distill_loss_val = 0.0
-    if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None:
+    if (
+        distillation_enabled
+        and distill_token_ids is not None
+        and distill_logprobs is not None
+    ):
         distill_loss = compute_distillation_loss(
             logits=scaled_logits,
             mask=mask,
@@ -332,7 +339,10 @@ def compute_grpo_loss(
             temperature=distillation_temperature,
             loss_type=distillation_loss_type,
         )
-        total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps
+        total_loss = (
+            total_loss
+            + (distillation_coef * distill_loss) / gradient_accumulation_steps
+        )
         distill_loss_val = distill_loss.item()

     metrics = {
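
Hunks three and four re-wrap the guard and the accumulation line in compute_grpo_loss: when distillation is enabled and teacher token ids/log-probs were provided, the distillation loss is weighted by distillation_coef, divided by gradient_accumulation_steps, and added to the loss that gets backpropagated. A sketch of that flow; compute_distillation_loss is stubbed, and the input shapes and the distill_token_ids / distill_logprobs structures are illustrative assumptions:

import torch


def compute_distillation_loss(logits, mask, **kwargs):
    # Stub for the real function from the first two hunks; returns a scalar tensor.
    return logits.float().mean() * 0.0 + 0.1


scaled_logits = torch.randn(2, 8, 128, requires_grad=True)
mask = torch.ones(2, 8)
total_loss = torch.tensor(1.5, requires_grad=True)  # stand-in for the GRPO policy loss
gradient_accumulation_steps = 4
distillation_enabled = True
distillation_coef = 0.5
distillation_temperature = 1.0
distillation_loss_type = "kl"       # hypothetical value
distill_token_ids = [[5, 17, 42]]   # placeholder; real structure not shown in this diff
distill_logprobs = [[-0.3, -1.7, -2.9]]

distill_loss_val = 0.0
if (
    distillation_enabled
    and distill_token_ids is not None
    and distill_logprobs is not None
):
    distill_loss = compute_distillation_loss(
        logits=scaled_logits,
        mask=mask,
        temperature=distillation_temperature,
        loss_type=distillation_loss_type,
    )
    # Weighted distillation term, averaged over accumulation steps, added to the total.
    total_loss = (
        total_loss
        + (distillation_coef * distill_loss) / gradient_accumulation_steps
    )
    distill_loss_val = distill_loss.item()
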
@@ -437,9 +447,13 @@ def run_training_step(
         inf_logprobs = inference_logprob_batches[batch_idx]
         distill_ids = None
         distill_lps = None
-        if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches):
+        if distill_token_id_batches is not None and batch_idx < len(
+            distill_token_id_batches
+        ):
             distill_ids = distill_token_id_batches[batch_idx]
-        if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches):
+        if distill_logprob_batches is not None and batch_idx < len(
+            distill_logprob_batches
+        ):
             distill_lps = distill_logprob_batches[batch_idx]

         loss, metrics = compute_grpo_loss(
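
The last hunk re-wraps the bounds-checked lookups in run_training_step that pick the current microbatch's teacher token ids and log-probs, leaving them as None when nothing was supplied, before passing them on to compute_grpo_loss. A sketch of that selection; the per-microbatch lists here are hypothetical stand-ins:

# Hypothetical per-microbatch lists of teacher token ids / log-probs.
distill_token_id_batches = [[[5, 17, 42]], [[3, 9]]]
distill_logprob_batches = [[[-0.3, -1.7, -2.9]], [[-0.1, -2.4]]]

for batch_idx in range(2):
    distill_ids = None
    distill_lps = None
    # Guard against missing or shorter-than-expected batch lists.
    if distill_token_id_batches is not None and batch_idx < len(
        distill_token_id_batches
    ):
        distill_ids = distill_token_id_batches[batch_idx]
    if distill_logprob_batches is not None and batch_idx < len(
        distill_logprob_batches
    ):
        distill_lps = distill_logprob_batches[batch_idx]

    # distill_ids / distill_lps would then be forwarded to compute_grpo_loss(...).
    print(batch_idx, distill_ids, distill_lps)
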