diff --git a/example_trainer/cli.py b/example_trainer/cli.py index 93946d51..1e46bfc9 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -163,23 +163,6 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None: default=0.2, help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].", ) - group.add_argument( - "--distill-enabled", - action="store_true", - help="Enable teacher distillation loss (requires distill payload in Atropos batch).", - ) - group.add_argument( - "--distill-coef", - type=float, - default=0.0, - help="Coefficient for distillation loss term.", - ) - group.add_argument( - "--distill-temperature", - type=float, - default=1.0, - help="Temperature for teacher top-k distribution in distillation loss.", - ) def add_vllm_args(parser: argparse.ArgumentParser) -> None: @@ -441,9 +424,6 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: checkpoint_interval=getattr(args, "checkpoint_interval", 3), # GRPO/PPO hyperparameters clip_eps=getattr(args, "clip_eps", 0.2), - distill_enabled=getattr(args, "distill_enabled", False), - distill_coef=getattr(args, "distill_coef", 0.0), - distill_temperature=getattr(args, "distill_temperature", 1.0), adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False), adafactor_relative_step=getattr(args, "adafactor_relative_step", False), # vLLM settings diff --git a/example_trainer/config.py b/example_trainer/config.py index 03fd80a8..4ddeddb5 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -69,18 +69,6 @@ class TrainingConfig(BaseModel): "Prevents large policy updates that could destabilize training." ), ) - distill_enabled: bool = Field( - False, - description="Enable teacher distillation loss when distill tensors are present.", - ) - distill_coef: float = Field( - 0.0, - description="Weight for distillation loss in total loss.", - ) - distill_temperature: float = Field( - 1.0, - description="Temperature applied when converting teacher top-k logprobs.", - ) # === Device & Storage === device: str = Field( "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" diff --git a/example_trainer/data.py b/example_trainer/data.py index 0aa1a88a..16a38564 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -29,8 +29,6 @@ def pad_data_to_good_offset( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels) - Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k] - Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k] ]: """ Pad and batch data from the Atropos API. @@ -47,8 +45,7 @@ def pad_data_to_good_offset( extract_inference_logprobs: Whether to extract inference logprobs Returns: - Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, - inference_logprob_batches, distill_token_id_batches, distill_logprob_batches) + Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data Note: @@ -76,10 +73,6 @@ def pad_data_to_good_offset( temperatures = [] inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape has_any_logprobs = False - distill_token_ids_padded: List[np.ndarray] = [] - distill_logprobs_padded: List[np.ndarray] = [] - has_any_distill = False - max_distill_k = 1 for item in data["batch"]: # Normalize advantage scores @@ -160,85 +153,6 @@ def pad_data_to_good_offset( np.full(token_setup_len - 1, 1.0, dtype=np.float32) ) - # Extract teacher distillation top-k arrays if available. - # Expected shape in incoming payload: [sequence][position][k]. - if "distill_token_ids" in item and "distill_logprobs" in item: - seq_token_ids = item["distill_token_ids"] - seq_logprobs = item["distill_logprobs"] - if ( - isinstance(seq_token_ids, list) - and isinstance(seq_logprobs, list) - and i < len(seq_token_ids) - and i < len(seq_logprobs) - and seq_token_ids[i] is not None - and seq_logprobs[i] is not None - ): - per_pos_token_ids = seq_token_ids[i] - per_pos_logprobs = seq_logprobs[i] - if ( - isinstance(per_pos_token_ids, list) - and isinstance(per_pos_logprobs, list) - and len(per_pos_token_ids) == len(per_pos_logprobs) - ): - local_k = 1 - for row_ids in per_pos_token_ids: - if isinstance(row_ids, list): - local_k = max(local_k, len(row_ids)) - max_distill_k = max(max_distill_k, local_k) - has_any_distill = True - - rows = max(0, token_setup_len - 1) - token_mat = np.full((rows, local_k), -1, dtype=np.int64) - logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32) - - # Shift by one to align with causal labels like inference_logprobs. - copy_positions = min( - len(per_pos_token_ids), - len(per_pos_logprobs), - token_setup_len, - ) - for pos in range(1, copy_positions): - src_ids = per_pos_token_ids[pos] - src_lps = per_pos_logprobs[pos] - if not isinstance(src_ids, list) or not isinstance( - src_lps, list - ): - continue - topk = min(local_k, len(src_ids), len(src_lps)) - if topk <= 0: - continue - token_mat[pos - 1, :topk] = np.array( - src_ids[:topk], dtype=np.int64 - ) - logprob_mat[pos - 1, :topk] = np.array( - src_lps[:topk], dtype=np.float32 - ) - - distill_token_ids_padded.append(token_mat) - distill_logprobs_padded.append(logprob_mat) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append( - np.full((rows, 1), -1, dtype=np.int64) - ) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append( - np.full((rows, 1), -1, dtype=np.int64) - ) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 if ( @@ -264,8 +178,6 @@ def pad_data_to_good_offset( advantage_batches = [] temperature_batches = [] inference_logprob_batches = [] - distill_token_id_batches = [] - distill_logprob_batches = [] for start in range(0, len(input_ids), batch_size): end = min(start + batch_size, len(input_ids)) @@ -287,46 +199,12 @@ def pad_data_to_good_offset( torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) ) - if distill_token_ids_padded and distill_logprobs_padded: - seq_slice_ids = distill_token_ids_padded[start:end] - seq_slice_lps = distill_logprobs_padded[start:end] - normalized_ids = [] - normalized_lps = [] - for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps): - if ids_mat.shape[1] < max_distill_k: - pad_cols = max_distill_k - ids_mat.shape[1] - ids_mat = np.pad( - ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1 - ) - lps_mat = np.pad( - lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9 - ) - normalized_ids.append(ids_mat) - normalized_lps.append(lps_mat) - - distill_token_id_batches.append( - torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long) - ) - distill_logprob_batches.append( - torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32) - ) - # Return inference logprob batches if we have any real logprobs final_logprob_batches = ( inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None ) - final_distill_token_id_batches = ( - distill_token_id_batches - if (has_any_distill and distill_token_id_batches) - else None - ) - final_distill_logprob_batches = ( - distill_logprob_batches - if (has_any_distill and distill_logprob_batches) - else None - ) return ( token_batches, @@ -334,8 +212,6 @@ def pad_data_to_good_offset( advantage_batches, temperature_batches, final_logprob_batches, - final_distill_token_id_batches, - final_distill_logprob_batches, ) @@ -352,8 +228,6 @@ def get_data( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches - Optional[List[torch.Tensor]], # distill_token_id_batches - Optional[List[torch.Tensor]], # distill_logprob_batches ] ], None, # Legacy return (no longer used) @@ -377,11 +251,47 @@ def get_data( - inference_logprob_batches are aligned with labels for proper GRPO loss computation """ batches = [] + _logged_logprob_warning = False while True: data = get_batch(url=atropos_url) if data["batch"] is not None: + # DEBUG: Check if inference_logprobs exists in the data + if not _logged_logprob_warning: + has_logprobs = any( + "inference_logprobs" in item for item in data["batch"] + ) + if has_logprobs: + # Check if they're non-empty + sample_item = next( + ( + item + for item in data["batch"] + if "inference_logprobs" in item + ), + None, + ) + if sample_item and sample_item.get("inference_logprobs"): + sample_lp = ( + sample_item["inference_logprobs"][0] + if sample_item["inference_logprobs"] + else [] + ) + print( + f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})" + ) + else: + print( + " [Data] ⚠ inference_logprobs key exists but is empty!" + ) + else: + print(" [Data] ⚠ NO inference_logprobs in batch data!") + print( + f" [Data] Keys in first item: {list(data['batch'][0].keys())}" + ) + _logged_logprob_warning = True + # Process and accumulate batches (now includes batched inference logprobs) ( token_batches, @@ -389,8 +299,6 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, - distill_token_id_batches, - distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) # Include inference logprob batches in the tuple @@ -401,8 +309,6 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, - distill_token_id_batches, - distill_logprob_batches, ) ) diff --git a/example_trainer/run.py b/example_trainer/run.py index d1cf37b2..b9b5f88f 100644 --- a/example_trainer/run.py +++ b/example_trainer/run.py @@ -201,9 +201,6 @@ def main(): checkpoint_interval=args.checkpoint_interval, # GRPO hyperparameters clip_eps=args.clip_eps, - distill_enabled=getattr(args, "distill_enabled", False), - distill_coef=getattr(args, "distill_coef", 0.0), - distill_temperature=getattr(args, "distill_temperature", 1.0), # vLLM settings vllm_port=args.vllm_port, vllm_gpu_memory_utilization=args.gpu_memory_utilization, diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index bff1763f..4c9e2893 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -170,8 +170,6 @@ def train_legacy(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -194,8 +192,6 @@ def train_legacy(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -317,30 +313,17 @@ def train_shared_vllm(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO loss) data_fetch_start = time.time() if len(batches) == 0: - print(" [Trainer] requesting data from Atropos API...", flush=True) batches, _ = get_data( config.batch_size, config.seq_len, config.atropos_url, extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs ) - print( - f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)", - flush=True, - ) batch_data = batches.pop(0) token_batches, label_batches, advantage_batches, temperature_batches = ( batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None - token_shapes = [tuple(tb.shape) for tb in token_batches] - print( - " [Trainer] selected trainer batch: " - f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}", - flush=True, - ) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -356,8 +339,6 @@ def train_shared_vllm(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -503,8 +484,6 @@ def train_lora(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -520,8 +499,6 @@ def train_lora(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -729,8 +706,6 @@ def train_lora_restart(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -746,8 +721,6 @@ def train_lora_restart(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) diff --git a/example_trainer/training.py b/example_trainer/training.py index b7cab944..035d45c7 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -70,11 +70,6 @@ def compute_grpo_loss( gradient_accumulation_steps: int, inference_logprobs: Optional[torch.Tensor] = None, clip_eps: float = 0.2, - distill_token_ids: Optional[torch.Tensor] = None, - distill_logprobs: Optional[torch.Tensor] = None, - distill_enabled: bool = False, - distill_coef: float = 0.0, - distill_temperature: float = 1.0, ) -> Tuple[torch.Tensor, dict]: """ Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. @@ -130,9 +125,6 @@ def compute_grpo_loss( logprob_diff_abs_mean = 0.0 logprob_diff_max = 0.0 - distill_loss_value = torch.tensor(0.0, device=logp_per_token.device) - distill_token_count = 0.0 - # === GRPO/PPO Loss Computation === if inference_logprobs is not None: # Move inference logprobs to correct device/dtype @@ -195,23 +187,7 @@ def compute_grpo_loss( # Average over tokens, then over batch policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() - if ( - distill_enabled - and distill_coef > 0 - and distill_token_ids is not None - and distill_logprobs is not None - ): - distill_loss_value, distill_token_count = compute_distillation_loss( - logits=scaled_logits, - labels=labels, - distill_token_ids=distill_token_ids.to(logits.device), - distill_logprobs=distill_logprobs.to(logits.device, logits.dtype), - temperature=max(1e-6, float(distill_temperature)), - ) - - total_loss = (policy_loss + distill_coef * distill_loss_value) / ( - gradient_accumulation_steps - ) + total_loss = policy_loss / gradient_accumulation_steps # Compute metrics for logging with torch.no_grad(): @@ -277,77 +253,11 @@ def compute_grpo_loss( "logprob_diff_mean": logprob_diff_mean, "logprob_diff_abs_mean": logprob_diff_abs_mean, "logprob_diff_max": logprob_diff_max, - "distill_loss": ( - distill_loss_value.item() - if torch.is_tensor(distill_loss_value) - else float(distill_loss_value) - ), - "distill_token_count": distill_token_count, } return total_loss, metrics -def compute_distillation_loss( - logits: torch.Tensor, - labels: torch.Tensor, - distill_token_ids: torch.Tensor, - distill_logprobs: torch.Tensor, - temperature: float = 1.0, -) -> Tuple[torch.Tensor, float]: - """ - Compute token-level distillation loss from teacher top-k prompt logprobs. - - Args: - logits: Student logits [batch, seq_len, vocab] - labels: Labels [batch, seq_len], -100 for masked positions - distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded - distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded - temperature: Distillation temperature - - Returns: - Tuple of (distillation loss scalar, valid token count) - """ - if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3: - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - if ( - distill_token_ids.shape[:2] != labels.shape - or distill_logprobs.shape != distill_token_ids.shape - ): - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - temp = max(1e-6, float(temperature)) - - valid_ids = distill_token_ids >= 0 - label_mask = labels != -100 - valid_pos = label_mask & valid_ids.any(dim=-1) - if not valid_pos.any(): - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - gather_ids = distill_token_ids.clamp_min(0).long() - - # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor - # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs. - # Instead: gather raw logits at top-k positions, then subtract logsumexp. - # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab]. - scaled_logits = logits / temp - log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1] - student_logp_topk = ( - torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer - ) - - masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) - teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) - - per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1) - per_token_loss = per_token_loss * valid_pos.to(per_token_loss.dtype) - - token_count = valid_pos.sum().item() - loss = per_token_loss.sum() / valid_pos.sum().clamp_min(1).to(per_token_loss.dtype) - return loss, float(token_count) - - def run_training_step( model: torch.nn.Module, optimizer: torch.optim.Optimizer, @@ -358,8 +268,6 @@ def run_training_step( config: TrainingConfig, step_idx: int, inference_logprob_batches: Optional[List[torch.Tensor]] = None, - distill_token_id_batches: Optional[List[torch.Tensor]] = None, - distill_logprob_batches: Optional[List[torch.Tensor]] = None, ) -> dict: """ Run a single training step with gradient accumulation. @@ -394,8 +302,6 @@ def run_training_step( total_logprob_diff_mean = 0.0 total_logprob_diff_abs_mean = 0.0 total_logprob_diff_max = 0.0 - total_distill_loss = 0.0 - total_distill_tokens = 0.0 grad_norm = 0.0 all_training_logprobs: List[torch.Tensor] = [] all_inference_logprobs: List[torch.Tensor] = [] @@ -419,13 +325,6 @@ def run_training_step( for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( zip(token_batches, label_batches, advantage_batches, temperature_batches) ): - print( - f" [Step] micro-batch {batch_idx+1}/{num_batches} " - f"tokens={tokens.shape} " - f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB " - f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB", - flush=True, - ) tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) @@ -436,18 +335,7 @@ def run_training_step( inference_logprob_batches ): inf_logprobs = inference_logprob_batches[batch_idx] - distill_ids = None - if distill_token_id_batches is not None and batch_idx < len( - distill_token_id_batches - ): - distill_ids = distill_token_id_batches[batch_idx] - distill_lps = None - if distill_logprob_batches is not None and batch_idx < len( - distill_logprob_batches - ): - distill_lps = distill_logprob_batches[batch_idx] - print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True) loss, metrics = compute_grpo_loss( model, tokens, @@ -457,20 +345,9 @@ def run_training_step( config.gradient_accumulation_steps, inference_logprobs=inf_logprobs, clip_eps=clip_eps, - distill_token_ids=distill_ids, - distill_logprobs=distill_lps, - distill_enabled=bool(getattr(config, "distill_enabled", False)), - distill_coef=float(getattr(config, "distill_coef", 0.0)), - distill_temperature=float(getattr(config, "distill_temperature", 1.0)), ) - print( - f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} " - f"backward...", - flush=True, - ) loss.backward() - print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True) total_loss += loss.item() total_pos_logp += metrics["pos_logp"] total_neg_logp += metrics["neg_logp"] @@ -487,8 +364,6 @@ def run_training_step( total_logprob_diff_max = max( total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0) ) - total_distill_loss += metrics.get("distill_loss", 0.0) - total_distill_tokens += metrics.get("distill_token_count", 0.0) # Collect logprobs for alignment monitoring if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: @@ -524,8 +399,6 @@ def run_training_step( # GRPO-specific metrics (averaged over batches) "mean_ratio": total_mean_ratio / num_batches, "clipped_fraction": total_clipped_fraction / num_batches, - "distill_loss": total_distill_loss / num_batches, - "distill_token_count": total_distill_tokens, } # Compute logprob alignment stats for monitoring @@ -599,12 +472,6 @@ def log_metrics( clipped_frac = metrics.get("clipped_fraction", 0) print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%") - if metrics.get("distill_token_count", 0) > 0: - print( - " Distill: " - f"loss={metrics.get('distill_loss', 0.0):.4f}, " - f"tokens={int(metrics.get('distill_token_count', 0))}" - ) # Advantage distribution if "pos_count" in metrics or "neg_count" in metrics: @@ -627,8 +494,6 @@ def log_metrics( # GRPO-specific metrics "grpo/mean_ratio": mean_ratio, "grpo/clipped_fraction": clipped_frac, - "distill/loss": metrics.get("distill_loss", 0.0), - "distill/token_count": metrics.get("distill_token_count", 0.0), } # Add timing metrics if present for key in [