Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-28 17:29:30 +00:00
remove training code
parent 862cd3667d
commit 148a4fd5eb
6 changed files with 38 additions and 329 deletions
@@ -70,11 +70,6 @@ def compute_grpo_loss(
     gradient_accumulation_steps: int,
     inference_logprobs: Optional[torch.Tensor] = None,
     clip_eps: float = 0.2,
-    distill_token_ids: Optional[torch.Tensor] = None,
-    distill_logprobs: Optional[torch.Tensor] = None,
-    distill_enabled: bool = False,
-    distill_coef: float = 0.0,
-    distill_temperature: float = 1.0,
 ) -> Tuple[torch.Tensor, dict]:
     """
     Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
@@ -130,9 +125,6 @@ def compute_grpo_loss(
     logprob_diff_abs_mean = 0.0
     logprob_diff_max = 0.0
 
-    distill_loss_value = torch.tensor(0.0, device=logp_per_token.device)
-    distill_token_count = 0.0
-
     # === GRPO/PPO Loss Computation ===
     if inference_logprobs is not None:
         # Move inference logprobs to correct device/dtype
@@ -195,23 +187,7 @@ def compute_grpo_loss(
         # Average over tokens, then over batch
         policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
 
-    if (
-        distill_enabled
-        and distill_coef > 0
-        and distill_token_ids is not None
-        and distill_logprobs is not None
-    ):
-        distill_loss_value, distill_token_count = compute_distillation_loss(
-            logits=scaled_logits,
-            labels=labels,
-            distill_token_ids=distill_token_ids.to(logits.device),
-            distill_logprobs=distill_logprobs.to(logits.device, logits.dtype),
-            temperature=max(1e-6, float(distill_temperature)),
-        )
-
-    total_loss = (policy_loss + distill_coef * distill_loss_value) / (
-        gradient_accumulation_steps
-    )
+    total_loss = policy_loss / gradient_accumulation_steps
 
     # Compute metrics for logging
     with torch.no_grad():
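The surviving `total_loss = policy_loss / gradient_accumulation_steps` line keeps the standard PPO-style clipped surrogate and pre-scales it for gradient accumulation. For reference, a minimal sketch of that objective follows; the helper name `demo_clipped_policy_loss` and the toy shapes are illustrative assumptions, not repository code, though the token-then-batch averaging mirrors the `policy_loss` line above.

import torch

def demo_clipped_policy_loss(
    train_logp: torch.Tensor,      # [batch, seq] logprobs under the training policy
    inference_logp: torch.Tensor,  # [batch, seq] logprobs recorded at sampling time
    advantages: torch.Tensor,      # [batch, 1] group-relative advantages
    mask: torch.Tensor,            # [batch, seq] 1.0 for trained tokens, 0.0 for padding
    clip_eps: float = 0.2,
) -> torch.Tensor:
    # Importance ratio between the current policy and the sampling policy.
    ratio = torch.exp(train_logp - inference_logp)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    per_token = -torch.min(unclipped, clipped)
    # Average over tokens, then over the batch, as in the hunk above.
    mask_sum = mask.sum(dim=-1).clamp_min(1)
    return ((per_token * mask).sum(dim=-1) / mask_sum).mean()

torch.manual_seed(0)
lp_new = torch.randn(2, 5)
lp_old = lp_new + 0.1 * torch.randn(2, 5)
print(demo_clipped_policy_loss(lp_new, lp_old, torch.tensor([[1.0], [-1.0]]), torch.ones(2, 5)))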
@@ -277,77 +253,11 @@ def compute_grpo_loss(
         "logprob_diff_mean": logprob_diff_mean,
         "logprob_diff_abs_mean": logprob_diff_abs_mean,
         "logprob_diff_max": logprob_diff_max,
-        "distill_loss": (
-            distill_loss_value.item()
-            if torch.is_tensor(distill_loss_value)
-            else float(distill_loss_value)
-        ),
-        "distill_token_count": distill_token_count,
     }
 
     return total_loss, metrics
 
 
-def compute_distillation_loss(
-    logits: torch.Tensor,
-    labels: torch.Tensor,
-    distill_token_ids: torch.Tensor,
-    distill_logprobs: torch.Tensor,
-    temperature: float = 1.0,
-) -> Tuple[torch.Tensor, float]:
-    """
-    Compute token-level distillation loss from teacher top-k prompt logprobs.
-
-    Args:
-        logits: Student logits [batch, seq_len, vocab]
-        labels: Labels [batch, seq_len], -100 for masked positions
-        distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded
-        distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded
-        temperature: Distillation temperature
-
-    Returns:
-        Tuple of (distillation loss scalar, valid token count)
-    """
-    if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
-        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
-
-    if (
-        distill_token_ids.shape[:2] != labels.shape
-        or distill_logprobs.shape != distill_token_ids.shape
-    ):
-        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
-
-    temp = max(1e-6, float(temperature))
-
-    valid_ids = distill_token_ids >= 0
-    label_mask = labels != -100
-    valid_pos = label_mask & valid_ids.any(dim=-1)
-    if not valid_pos.any():
-        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
-
-    gather_ids = distill_token_ids.clamp_min(0).long()
-
-    # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor
-    # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs.
-    # Instead: gather raw logits at top-k positions, then subtract logsumexp.
-    # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
-    scaled_logits = logits / temp
-    log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
-    student_logp_topk = (
-        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
-    )
-
-    masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
-    teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
-
-    per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1)
-    per_token_loss = per_token_loss * valid_pos.to(per_token_loss.dtype)
-
-    token_count = valid_pos.sum().item()
-    loss = per_token_loss.sum() / valid_pos.sum().clamp_min(1).to(per_token_loss.dtype)
-    return loss, float(token_count)
-
-
 def run_training_step(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
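The deleted `compute_distillation_loss` combined two ideas: token-level cross-entropy against the teacher's top-k distribution, and the identity `log_softmax(x)[i] == x[i] - logsumexp(x)`, which yields the student's top-k logprobs without materializing the full `[batch, seq_len, vocab]` log-softmax. A self-contained sketch on toy sizes (names and shapes here are illustrative, not repository code):

import torch

batch, seq, vocab, k = 2, 4, 50, 3
logits = torch.randn(batch, seq, vocab)                            # student logits
topk_ids = torch.randint(0, vocab, (batch, seq, k))                # teacher top-k token IDs
topk_logp = torch.log_softmax(torch.randn(batch, seq, k), dim=-1)  # toy teacher top-k logprobs

# Memory-light route: gather raw logits at the top-k positions, then
# subtract the per-position logsumexp normalizer. Output is [b, s, k].
log_norm = torch.logsumexp(logits, dim=-1, keepdim=True)           # [b, s, 1]
student_logp_topk = torch.gather(logits, -1, topk_ids) - log_norm  # [b, s, k]

# Same values the expensive route would give via a full log_softmax.
full = torch.log_softmax(logits, dim=-1)                           # [b, s, vocab]
assert torch.allclose(student_logp_topk, torch.gather(full, -1, topk_ids), atol=1e-5)

# Cross-entropy of the student's top-k logprobs under the teacher's top-k probs.
teacher_probs = topk_logp.exp()                                    # [b, s, k]
per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1)  # [b, s]
print(per_token_loss.mean())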
@@ -358,8 +268,6 @@ def run_training_step(
     config: TrainingConfig,
     step_idx: int,
     inference_logprob_batches: Optional[List[torch.Tensor]] = None,
-    distill_token_id_batches: Optional[List[torch.Tensor]] = None,
-    distill_logprob_batches: Optional[List[torch.Tensor]] = None,
 ) -> dict:
     """
     Run a single training step with gradient accumulation.
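The docstring names the pattern the later hunks depend on: each micro-batch loss arrives pre-divided by `gradient_accumulation_steps` (see `compute_grpo_loss` above), `backward()` runs once per micro-batch so gradients sum in place, and the optimizer steps once at the end. A minimal sketch under those assumptions, with a toy model standing in for the real one; the gradient-norm clipping line is an assumption based on the `grad_norm` accumulator below, not repository code.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(4, 8) for _ in range(3)]
accum_steps = len(micro_batches)

optimizer.zero_grad()
for xb in micro_batches:
    # Pre-divide so the accumulated gradient averages over micro-batches.
    loss = model(xb).pow(2).mean() / accum_steps
    loss.backward()  # gradients accumulate in .grad across calls
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()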
@@ -394,8 +302,6 @@ def run_training_step(
     total_logprob_diff_mean = 0.0
     total_logprob_diff_abs_mean = 0.0
     total_logprob_diff_max = 0.0
-    total_distill_loss = 0.0
-    total_distill_tokens = 0.0
     grad_norm = 0.0
     all_training_logprobs: List[torch.Tensor] = []
     all_inference_logprobs: List[torch.Tensor] = []
@@ -419,13 +325,6 @@ def run_training_step(
     for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
         zip(token_batches, label_batches, advantage_batches, temperature_batches)
     ):
-        print(
-            f" [Step] micro-batch {batch_idx+1}/{num_batches} "
-            f"tokens={tokens.shape} "
-            f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB "
-            f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB",
-            flush=True,
-        )
         tokens = tokens.to(config.device)
         labels = labels.to(config.device)
         advantages = advantages.to(config.device)
@@ -436,18 +335,7 @@ def run_training_step(
             inference_logprob_batches
         ):
             inf_logprobs = inference_logprob_batches[batch_idx]
-        distill_ids = None
-        if distill_token_id_batches is not None and batch_idx < len(
-            distill_token_id_batches
-        ):
-            distill_ids = distill_token_id_batches[batch_idx]
-        distill_lps = None
-        if distill_logprob_batches is not None and batch_idx < len(
-            distill_logprob_batches
-        ):
-            distill_lps = distill_logprob_batches[batch_idx]
 
-        print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True)
         loss, metrics = compute_grpo_loss(
             model,
             tokens,
@@ -457,20 +345,9 @@ def run_training_step(
             config.gradient_accumulation_steps,
             inference_logprobs=inf_logprobs,
             clip_eps=clip_eps,
-            distill_token_ids=distill_ids,
-            distill_logprobs=distill_lps,
-            distill_enabled=bool(getattr(config, "distill_enabled", False)),
-            distill_coef=float(getattr(config, "distill_coef", 0.0)),
-            distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
         )
 
-        print(
-            f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} "
-            f"backward...",
-            flush=True,
-        )
         loss.backward()
-        print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
         total_loss += loss.item()
         total_pos_logp += metrics["pos_logp"]
         total_neg_logp += metrics["neg_logp"]
@@ -487,8 +364,6 @@ def run_training_step(
         total_logprob_diff_max = max(
             total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
         )
-        total_distill_loss += metrics.get("distill_loss", 0.0)
-        total_distill_tokens += metrics.get("distill_token_count", 0.0)
 
         # Collect logprobs for alignment monitoring
         if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
@@ -524,8 +399,6 @@ def run_training_step(
         # GRPO-specific metrics (averaged over batches)
         "mean_ratio": total_mean_ratio / num_batches,
         "clipped_fraction": total_clipped_fraction / num_batches,
-        "distill_loss": total_distill_loss / num_batches,
-        "distill_token_count": total_distill_tokens,
     }
 
     # Compute logprob alignment stats for monitoring
@@ -599,12 +472,6 @@ def log_metrics(
     clipped_frac = metrics.get("clipped_fraction", 0)
 
     print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%")
-    if metrics.get("distill_token_count", 0) > 0:
-        print(
-            " Distill: "
-            f"loss={metrics.get('distill_loss', 0.0):.4f}, "
-            f"tokens={int(metrics.get('distill_token_count', 0))}"
-        )
 
     # Advantage distribution
     if "pos_count" in metrics or "neg_count" in metrics:
@@ -627,8 +494,6 @@ def log_metrics(
         # GRPO-specific metrics
         "grpo/mean_ratio": mean_ratio,
         "grpo/clipped_fraction": clipped_frac,
-        "distill/loss": metrics.get("distill_loss", 0.0),
-        "distill/token_count": metrics.get("distill_token_count", 0.0),
     }
     # Add timing metrics if present
     for key in [