remove training code

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:52:52 -04:00
parent 862cd3667d
commit 148a4fd5eb
6 changed files with 38 additions and 329 deletions

View file

@ -163,23 +163,6 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
default=0.2,
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
)
group.add_argument(
"--distill-enabled",
action="store_true",
help="Enable teacher distillation loss (requires distill payload in Atropos batch).",
)
group.add_argument(
"--distill-coef",
type=float,
default=0.0,
help="Coefficient for distillation loss term.",
)
group.add_argument(
"--distill-temperature",
type=float,
default=1.0,
help="Temperature for teacher top-k distribution in distillation loss.",
)
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
@ -441,9 +424,6 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
clip_eps=getattr(args, "clip_eps", 0.2),
distill_enabled=getattr(args, "distill_enabled", False),
distill_coef=getattr(args, "distill_coef", 0.0),
distill_temperature=getattr(args, "distill_temperature", 1.0),
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
# vLLM settings

View file

@ -69,18 +69,6 @@ class TrainingConfig(BaseModel):
"Prevents large policy updates that could destabilize training."
),
)
distill_enabled: bool = Field(
False,
description="Enable teacher distillation loss when distill tensors are present.",
)
distill_coef: float = Field(
0.0,
description="Weight for distillation loss in total loss.",
)
distill_temperature: float = Field(
1.0,
description="Temperature applied when converting teacher top-k logprobs.",
)
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"

View file

@ -29,8 +29,6 @@ def pad_data_to_good_offset(
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k]
Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k]
]:
"""
Pad and batch data from the Atropos API.
@ -47,8 +45,7 @@ def pad_data_to_good_offset(
extract_inference_logprobs: Whether to extract inference logprobs
Returns:
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches,
inference_logprob_batches, distill_token_id_batches, distill_logprob_batches)
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
Note:
@ -76,10 +73,6 @@ def pad_data_to_good_offset(
temperatures = []
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
has_any_logprobs = False
distill_token_ids_padded: List[np.ndarray] = []
distill_logprobs_padded: List[np.ndarray] = []
has_any_distill = False
max_distill_k = 1
for item in data["batch"]:
# Normalize advantage scores
@ -160,85 +153,6 @@ def pad_data_to_good_offset(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
# Extract teacher distillation top-k arrays if available.
# Expected shape in incoming payload: [sequence][position][k].
if "distill_token_ids" in item and "distill_logprobs" in item:
seq_token_ids = item["distill_token_ids"]
seq_logprobs = item["distill_logprobs"]
if (
isinstance(seq_token_ids, list)
and isinstance(seq_logprobs, list)
and i < len(seq_token_ids)
and i < len(seq_logprobs)
and seq_token_ids[i] is not None
and seq_logprobs[i] is not None
):
per_pos_token_ids = seq_token_ids[i]
per_pos_logprobs = seq_logprobs[i]
if (
isinstance(per_pos_token_ids, list)
and isinstance(per_pos_logprobs, list)
and len(per_pos_token_ids) == len(per_pos_logprobs)
):
local_k = 1
for row_ids in per_pos_token_ids:
if isinstance(row_ids, list):
local_k = max(local_k, len(row_ids))
max_distill_k = max(max_distill_k, local_k)
has_any_distill = True
rows = max(0, token_setup_len - 1)
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32)
# Shift by one to align with causal labels like inference_logprobs.
copy_positions = min(
len(per_pos_token_ids),
len(per_pos_logprobs),
token_setup_len,
)
for pos in range(1, copy_positions):
src_ids = per_pos_token_ids[pos]
src_lps = per_pos_logprobs[pos]
if not isinstance(src_ids, list) or not isinstance(
src_lps, list
):
continue
topk = min(local_k, len(src_ids), len(src_lps))
if topk <= 0:
continue
token_mat[pos - 1, :topk] = np.array(
src_ids[:topk], dtype=np.int64
)
logprob_mat[pos - 1, :topk] = np.array(
src_lps[:topk], dtype=np.float32
)
distill_token_ids_padded.append(token_mat)
distill_logprobs_padded.append(logprob_mat)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(
np.full((rows, 1), -1, dtype=np.int64)
)
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(
np.full((rows, 1), -1, dtype=np.int64)
)
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
else:
rows = max(0, token_setup_len - 1)
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
distill_logprobs_padded.append(
np.full((rows, 1), -1e9, dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
if (
@ -264,8 +178,6 @@ def pad_data_to_good_offset(
advantage_batches = []
temperature_batches = []
inference_logprob_batches = []
distill_token_id_batches = []
distill_logprob_batches = []
for start in range(0, len(input_ids), batch_size):
end = min(start + batch_size, len(input_ids))
@ -287,46 +199,12 @@ def pad_data_to_good_offset(
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
)
if distill_token_ids_padded and distill_logprobs_padded:
seq_slice_ids = distill_token_ids_padded[start:end]
seq_slice_lps = distill_logprobs_padded[start:end]
normalized_ids = []
normalized_lps = []
for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps):
if ids_mat.shape[1] < max_distill_k:
pad_cols = max_distill_k - ids_mat.shape[1]
ids_mat = np.pad(
ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1
)
lps_mat = np.pad(
lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9
)
normalized_ids.append(ids_mat)
normalized_lps.append(lps_mat)
distill_token_id_batches.append(
torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long)
)
distill_logprob_batches.append(
torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32)
)
# Return inference logprob batches if we have any real logprobs
final_logprob_batches = (
inference_logprob_batches
if (has_any_logprobs and inference_logprob_batches)
else None
)
final_distill_token_id_batches = (
distill_token_id_batches
if (has_any_distill and distill_token_id_batches)
else None
)
final_distill_logprob_batches = (
distill_logprob_batches
if (has_any_distill and distill_logprob_batches)
else None
)
return (
token_batches,
@ -334,8 +212,6 @@ def pad_data_to_good_offset(
advantage_batches,
temperature_batches,
final_logprob_batches,
final_distill_token_id_batches,
final_distill_logprob_batches,
)
@ -352,8 +228,6 @@ def get_data(
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
Optional[List[torch.Tensor]], # distill_token_id_batches
Optional[List[torch.Tensor]], # distill_logprob_batches
]
],
None, # Legacy return (no longer used)
@ -377,11 +251,47 @@ def get_data(
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
"""
batches = []
_logged_logprob_warning = False
while True:
data = get_batch(url=atropos_url)
if data["batch"] is not None:
# DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning:
has_logprobs = any(
"inference_logprobs" in item for item in data["batch"]
)
if has_logprobs:
# Check if they're non-empty
sample_item = next(
(
item
for item in data["batch"]
if "inference_logprobs" in item
),
None,
)
if sample_item and sample_item.get("inference_logprobs"):
sample_lp = (
sample_item["inference_logprobs"][0]
if sample_item["inference_logprobs"]
else []
)
print(
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
)
else:
print(
" [Data] ⚠ inference_logprobs key exists but is empty!"
)
else:
print(" [Data] ⚠ NO inference_logprobs in batch data!")
print(
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
)
_logged_logprob_warning = True
# Process and accumulate batches (now includes batched inference logprobs)
(
token_batches,
@ -389,8 +299,6 @@ def get_data(
adv_batches,
temp_batches,
inf_logprob_batches,
distill_token_id_batches,
distill_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
# Include inference logprob batches in the tuple
@ -401,8 +309,6 @@ def get_data(
adv_batches,
temp_batches,
inf_logprob_batches,
distill_token_id_batches,
distill_logprob_batches,
)
)

View file

@ -201,9 +201,6 @@ def main():
checkpoint_interval=args.checkpoint_interval,
# GRPO hyperparameters
clip_eps=args.clip_eps,
distill_enabled=getattr(args, "distill_enabled", False),
distill_coef=getattr(args, "distill_coef", 0.0),
distill_temperature=getattr(args, "distill_temperature", 1.0),
# vLLM settings
vllm_port=args.vllm_port,
vllm_gpu_memory_utilization=args.gpu_memory_utilization,

View file

@ -170,8 +170,6 @@ def train_legacy(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -194,8 +192,6 @@ def train_legacy(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -317,30 +313,17 @@ def train_shared_vllm(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO loss)
data_fetch_start = time.time()
if len(batches) == 0:
print(" [Trainer] requesting data from Atropos API...", flush=True)
batches, _ = get_data(
config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
)
print(
f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)",
flush=True,
)
batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
token_shapes = [tuple(tb.shape) for tb in token_batches]
print(
" [Trainer] selected trainer batch: "
f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}",
flush=True,
)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -356,8 +339,6 @@ def train_shared_vllm(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -503,8 +484,6 @@ def train_lora(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -520,8 +499,6 @@ def train_lora(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -729,8 +706,6 @@ def train_lora_restart(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -746,8 +721,6 @@ def train_lora_restart(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)

View file

@ -70,11 +70,6 @@ def compute_grpo_loss(
gradient_accumulation_steps: int,
inference_logprobs: Optional[torch.Tensor] = None,
clip_eps: float = 0.2,
distill_token_ids: Optional[torch.Tensor] = None,
distill_logprobs: Optional[torch.Tensor] = None,
distill_enabled: bool = False,
distill_coef: float = 0.0,
distill_temperature: float = 1.0,
) -> Tuple[torch.Tensor, dict]:
"""
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
@ -130,9 +125,6 @@ def compute_grpo_loss(
logprob_diff_abs_mean = 0.0
logprob_diff_max = 0.0
distill_loss_value = torch.tensor(0.0, device=logp_per_token.device)
distill_token_count = 0.0
# === GRPO/PPO Loss Computation ===
if inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
@ -195,23 +187,7 @@ def compute_grpo_loss(
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
if (
distill_enabled
and distill_coef > 0
and distill_token_ids is not None
and distill_logprobs is not None
):
distill_loss_value, distill_token_count = compute_distillation_loss(
logits=scaled_logits,
labels=labels,
distill_token_ids=distill_token_ids.to(logits.device),
distill_logprobs=distill_logprobs.to(logits.device, logits.dtype),
temperature=max(1e-6, float(distill_temperature)),
)
total_loss = (policy_loss + distill_coef * distill_loss_value) / (
gradient_accumulation_steps
)
total_loss = policy_loss / gradient_accumulation_steps
# Compute metrics for logging
with torch.no_grad():
@ -277,77 +253,11 @@ def compute_grpo_loss(
"logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean,
"logprob_diff_max": logprob_diff_max,
"distill_loss": (
distill_loss_value.item()
if torch.is_tensor(distill_loss_value)
else float(distill_loss_value)
),
"distill_token_count": distill_token_count,
}
return total_loss, metrics
def compute_distillation_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    distill_token_ids: torch.Tensor,
    distill_logprobs: torch.Tensor,
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, float]:
    """
    Compute token-level distillation loss from teacher top-k logprobs.

    The loss is the cross-entropy between the teacher distribution
    (renormalised over its valid top-k entries) and the student
    log-probabilities gathered at the same token ids, averaged over every
    position that has a real label and at least one valid teacher entry.

    Args:
        logits: Student logits [batch, seq_len, vocab]
        labels: Labels [batch, seq_len], -100 for masked positions
        distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded
        distill_logprobs: Teacher top-k logprobs [batch, seq_len, k]; values at
            padded slots are ignored
        temperature: Distillation temperature (clamped to at least 1e-6)

    Returns:
        Tuple of (distillation loss scalar, valid token count). A zero loss
        and a count of 0.0 are returned when the inputs are not 3-D, their
        shapes disagree with `labels`, or no position is usable.
    """
    if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
    if (
        distill_token_ids.shape[:2] != labels.shape
        or distill_logprobs.shape != distill_token_ids.shape
    ):
        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0

    temp = max(1e-6, float(temperature))
    vocab_size = logits.shape[-1]

    # A slot is usable only if it is not padding (-1) and indexes into the
    # student vocabulary; an out-of-range id would make torch.gather fail.
    valid_ids = (distill_token_ids >= 0) & (distill_token_ids < vocab_size)
    has_valid = valid_ids.any(dim=-1)
    valid_pos = (labels != -100) & has_valid
    if not valid_pos.any():
        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0

    gather_ids = distill_token_ids.clamp(min=0, max=vocab_size - 1).long()

    # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor:
    # gather raw logits at the top-k ids, then subtract logsumexp. Only the
    # small [batch, seq_len, k] result is upcast to float32.
    scaled_logits = logits / temp
    log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)
    student_logp_topk = (
        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
    ).float()
    # Padded slots gather an arbitrary token; zero them so a -inf student
    # logit there cannot produce 0 * -inf = NaN below.
    student_logp_topk = student_logp_topk.masked_fill(~valid_ids, 0.0)

    # Build the teacher distribution in float32 so the mask value is exact in
    # every input dtype (a -1e9 fill is not representable in fp16).
    teacher_logits = distill_logprobs.float() / temp
    teacher_logits = teacher_logits.masked_fill(~valid_ids, float("-inf"))
    # Rows without any valid slot would softmax to NaN; give them finite
    # placeholder values. They are excluded from the loss by valid_pos.
    teacher_logits = teacher_logits.masked_fill(~has_valid.unsqueeze(-1), 0.0)
    teacher_probs = F.softmax(teacher_logits, dim=-1)

    per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1)
    per_token_loss = per_token_loss.masked_fill(~valid_pos, 0.0)

    valid_count = valid_pos.sum()
    token_count = valid_count.item()
    loss = per_token_loss.sum() / valid_count.clamp_min(1).to(per_token_loss.dtype)

    # Hand back the dtype the caller's inputs would have promoted to.
    out_dtype = torch.promote_types(logits.dtype, distill_logprobs.dtype)
    return loss.to(out_dtype), float(token_count)
def run_training_step(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
@ -358,8 +268,6 @@ def run_training_step(
config: TrainingConfig,
step_idx: int,
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
distill_token_id_batches: Optional[List[torch.Tensor]] = None,
distill_logprob_batches: Optional[List[torch.Tensor]] = None,
) -> dict:
"""
Run a single training step with gradient accumulation.
@ -394,8 +302,6 @@ def run_training_step(
total_logprob_diff_mean = 0.0
total_logprob_diff_abs_mean = 0.0
total_logprob_diff_max = 0.0
total_distill_loss = 0.0
total_distill_tokens = 0.0
grad_norm = 0.0
all_training_logprobs: List[torch.Tensor] = []
all_inference_logprobs: List[torch.Tensor] = []
@ -419,13 +325,6 @@ def run_training_step(
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
zip(token_batches, label_batches, advantage_batches, temperature_batches)
):
print(
f" [Step] micro-batch {batch_idx+1}/{num_batches} "
f"tokens={tokens.shape} "
f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB "
f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB",
flush=True,
)
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
@ -436,18 +335,7 @@ def run_training_step(
inference_logprob_batches
):
inf_logprobs = inference_logprob_batches[batch_idx]
distill_ids = None
if distill_token_id_batches is not None and batch_idx < len(
distill_token_id_batches
):
distill_ids = distill_token_id_batches[batch_idx]
distill_lps = None
if distill_logprob_batches is not None and batch_idx < len(
distill_logprob_batches
):
distill_lps = distill_logprob_batches[batch_idx]
print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True)
loss, metrics = compute_grpo_loss(
model,
tokens,
@ -457,20 +345,9 @@ def run_training_step(
config.gradient_accumulation_steps,
inference_logprobs=inf_logprobs,
clip_eps=clip_eps,
distill_token_ids=distill_ids,
distill_logprobs=distill_lps,
distill_enabled=bool(getattr(config, "distill_enabled", False)),
distill_coef=float(getattr(config, "distill_coef", 0.0)),
distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
)
print(
f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} "
f"backward...",
flush=True,
)
loss.backward()
print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True)
total_loss += loss.item()
total_pos_logp += metrics["pos_logp"]
total_neg_logp += metrics["neg_logp"]
@ -487,8 +364,6 @@ def run_training_step(
total_logprob_diff_max = max(
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
)
total_distill_loss += metrics.get("distill_loss", 0.0)
total_distill_tokens += metrics.get("distill_token_count", 0.0)
# Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
@ -524,8 +399,6 @@ def run_training_step(
# GRPO-specific metrics (averaged over batches)
"mean_ratio": total_mean_ratio / num_batches,
"clipped_fraction": total_clipped_fraction / num_batches,
"distill_loss": total_distill_loss / num_batches,
"distill_token_count": total_distill_tokens,
}
# Compute logprob alignment stats for monitoring
@ -599,12 +472,6 @@ def log_metrics(
clipped_frac = metrics.get("clipped_fraction", 0)
print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%")
if metrics.get("distill_token_count", 0) > 0:
print(
" Distill: "
f"loss={metrics.get('distill_loss', 0.0):.4f}, "
f"tokens={int(metrics.get('distill_token_count', 0))}"
)
# Advantage distribution
if "pos_count" in metrics or "neg_count" in metrics:
@ -627,8 +494,6 @@ def log_metrics(
# GRPO-specific metrics
"grpo/mean_ratio": mean_ratio,
"grpo/clipped_fraction": clipped_frac,
"distill/loss": metrics.get("distill_loss", 0.0),
"distill/token_count": metrics.get("distill_token_count", 0.0),
}
# Add timing metrics if present
for key in [