Review notes (free text ahead of the first "diff --git" header is ignored by patch tools):
  * Line structure was restored from a whitespace-collapsed copy of this patch.
    Hunk line counts were re-checked against every @@ header, but indentation
    (and any run of spaces inside string literals) is reconstructed -- confirm
    against the target files before applying.
  * training.py, compute_grpo_loss: the training logprobs are now detached and
    upcast to float32 before .cpu().numpy(), because Tensor.numpy() rejects
    tensors that require grad and bfloat16 tensors. The index hash for
    training.py therefore no longer matches the patched content.

diff --git a/example_trainer/data.py b/example_trainer/data.py
index 3d9b4cfd..f290adab 100644
--- a/example_trainer/data.py
+++ b/example_trainer/data.py
@@ -3,12 +3,14 @@ Data processing utilities for GRPO trainer.
 
 Handles data retrieval from Atropos API, padding, batching,
 and advantage normalization.
+
+Also extracts inference logprobs for alignment validation with training logprobs.
 """
 
 import json
 import math
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -16,8 +18,16 @@ import torch
 from .api import get_batch
 
 
-def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
-    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
+def pad_data_to_good_offset(
+    data: dict,
+    batch_size: int,
+    extract_inference_logprobs: bool = True,
+) -> Tuple[
+    List[torch.Tensor],
+    List[torch.Tensor],
+    List[torch.Tensor],
+    List[torch.Tensor],
+    Optional[List[np.ndarray]],
 ]:
     """
     Pad and batch data from the Atropos API.
@@ -26,13 +36,16 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
     - Pads token sequences to nearest multiple of 64
     - Normalizes advantage scores
     - Extracts temperature values
+    - Optionally extracts inference logprobs for alignment validation
 
     Args:
         data: Raw batch data from Atropos API
         batch_size: Size of each training batch
+        extract_inference_logprobs: Whether to extract inference logprobs
 
     Returns:
-        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches)
+        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
+        inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
     """
     max_token_len = max(
         [max([len(x) for x in item["tokens"]]) for item in data["batch"]]
@@ -53,6 +66,7 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
     advantages = []
     lengths = []
     temperatures = []
+    inference_logprobs_list: List[np.ndarray] = []
 
     for item in data["batch"]:
         # Normalize advantage scores
@@ -97,6 +111,14 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
             labels.append(label_item[1:])  # Shift by 1 for causal
             advantages.append(item["scores"][i])
 
+            # Extract inference logprobs for alignment validation
+            # These come from vLLM during rollout generation
+            if extract_inference_logprobs and "inference_logprobs" in item:
+                if i < len(item["inference_logprobs"]):
+                    inference_logprobs_list.append(
+                        np.array(item["inference_logprobs"][i], dtype=np.float32)
+                    )
+
             # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
             t = 1.0
             if (
@@ -137,16 +159,21 @@ def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
             ).view(-1, 1, 1)
         )
 
-    return token_batches, label_batches, advantage_batches, temperature_batches
+    # Return inference logprobs if available
+    inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
+
+    return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
 
 
 def get_data(
     batch_size: int,
     seq_len: int,
     atropos_url: str = "http://localhost:8000",
-) -> List[Tuple[
-    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
-]]:
+    extract_inference_logprobs: bool = True,
+) -> Tuple[
+    List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
+    Optional[List[np.ndarray]],
+]:
     """
     Fetch and process training data from the Atropos API.
 
@@ -157,11 +184,15 @@ def get_data(
         batch_size: Size of each training batch
         seq_len: Maximum sequence length (for reference, not used directly)
         atropos_url: URL of the Atropos API server
+        extract_inference_logprobs: Whether to extract inference logprobs for alignment
 
     Returns:
-        List of processed batch tuples
+        Tuple of (batches, all_inference_logprobs)
+        - batches: List of processed batch tuples
+        - all_inference_logprobs: List of inference logprob arrays for alignment validation
     """
     batches = []
+    all_inference_logprobs: List[np.ndarray] = []
 
     while True:
         data = get_batch(url=atropos_url)
@@ -172,10 +203,17 @@ def get_data(
                 json.dump(data, f)
 
             # Process and accumulate batches
-            batches.append(pad_data_to_good_offset(data, batch_size))
+            token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
+                pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
+
+            batches.append((token_batches, label_batches, adv_batches, temp_batches))
+
+            if inf_logprobs:
+                all_inference_logprobs.extend(inf_logprobs)
+
         elif len(batches) > 0:
             # Return accumulated batches when no more data
-            return batches
+            return batches, all_inference_logprobs if all_inference_logprobs else None
         else:
             # Wait for data
             time.sleep(1)
diff --git a/example_trainer/model.py b/example_trainer/model.py
index 71d52c99..37b8bf4f 100644
--- a/example_trainer/model.py
+++ b/example_trainer/model.py
@@ -510,26 +510,73 @@ def _create_vllm_to_hf_mapping(
     Handles fused layers:
     - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF)
     - gate_up_proj (vLLM) = gate_proj + up_proj (HF)
+
+    Uses actual tensor shapes from HF model to determine slice sizes,
+    rather than calculating from config (which can be wrong for some models).
     """
-    hf_params = set(model.state_dict().keys())
+    hf_state_dict = model.state_dict()
+    hf_params = set(hf_state_dict.keys())
     vllm_params = set(ipc_handles.keys())
 
-    # Get model config for dimension calculations
+    # Get model config for fallback dimension calculations
     model_config = model.config
     hidden_size = getattr(model_config, "hidden_size", 4096)
     num_attention_heads = getattr(model_config, "num_attention_heads", 32)
     num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads)
     intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
-    head_dim = hidden_size // num_attention_heads
+
+    # Try to get head_dim from config (some models like Qwen3 have this)
+    head_dim = getattr(model_config, "head_dim", None)
+    if head_dim is None:
+        head_dim = hidden_size // num_attention_heads
 
-    # QKV sizes
-    q_size = hidden_size
-    k_size = num_key_value_heads * head_dim
-    v_size = num_key_value_heads * head_dim
+    # Determine QKV sizes from ACTUAL HF model tensor shapes (more reliable)
+    # Look for a q_proj weight in the model to get the actual size
+    q_size = None
+    k_size = None
+    v_size = None
+
+    for name, param in hf_state_dict.items():
+        if "q_proj.weight" in name and q_size is None:
+            q_size = param.shape[0]  # Output dimension
+        elif "k_proj.weight" in name and k_size is None:
+            k_size = param.shape[0]
+        elif "v_proj.weight" in name and v_size is None:
+            v_size = param.shape[0]
+        if q_size and k_size and v_size:
+            break
+
+    # Fallback to calculated values if not found
+    if q_size is None:
+        q_size = num_attention_heads * head_dim
+    if k_size is None:
+        k_size = num_key_value_heads * head_dim
+    if v_size is None:
+        v_size = num_key_value_heads * head_dim
 
-    if debug:
-        print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
-              f"kv_heads={num_key_value_heads}, intermediate={intermediate_size}")
+    # Also get gate/up sizes from actual HF model
+    gate_size = None
+    up_size = None
+
+    for name, param in hf_state_dict.items():
+        if "gate_proj.weight" in name and gate_size is None:
+            gate_size = param.shape[0]
+        elif "up_proj.weight" in name and up_size is None:
+            up_size = param.shape[0]
+        if gate_size and up_size:
+            break
+
+    # Fallback
+    if gate_size is None:
+        gate_size = intermediate_size
+    if up_size is None:
+        up_size = intermediate_size
+
+    # Always print sizes for debugging weight sharing issues
+    print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
+          f"kv_heads={num_key_value_heads}, head_dim={head_dim}")
+    print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
+    print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
 
     mapping = {}
 
@@ -586,9 +633,9 @@ def _create_vllm_to_hf_mapping(
             fused_name = find_fused_source(hf_name, "gate_up_proj")
             if fused_name:
                 if "gate_proj" in hf_name:
-                    start, end = 0, intermediate_size
+                    start, end = 0, gate_size
                 else:
-                    start, end = intermediate_size, intermediate_size * 2
+                    start, end = gate_size, gate_size + up_size
 
                 mapping[hf_name] = {
                     "source": fused_name,
diff --git a/example_trainer/training.py b/example_trainer/training.py
index 860c15db..0994387a 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -2,13 +2,17 @@
 Training utilities for GRPO trainer.
 
 Contains loss computation, training step logic, and metric logging.
+
+Includes logprob alignment tracking to verify that training logprobs match
+inference logprobs at initialization (validates shared_vllm mode is working).
 """
 
 import random
 import string
 import time
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import wandb
@@ -16,6 +20,10 @@ import wandb
 from .config import TrainingConfig
 
 
+# Global storage for logprob alignment stats
+_logprob_alignment_stats: Dict[str, float] = {}
+
+
 def setup_wandb(config: TrainingConfig) -> bool:
     """
     Initialize Weights & Biases logging if enabled.
@@ -62,6 +70,7 @@ def compute_grpo_loss(
     advantages: torch.Tensor,
     temperatures: torch.Tensor,
     gradient_accumulation_steps: int,
+    inference_logprobs: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, dict]:
     """
     Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
@@ -77,6 +86,7 @@ def compute_grpo_loss(
         advantages: Advantage values [batch, 1]
         temperatures: Temperature values [batch, 1, 1]
         gradient_accumulation_steps: Number of accumulation steps (for scaling)
+        inference_logprobs: Optional logprobs from inference for alignment check
 
     Returns:
         Tuple of (loss tensor, metrics dict)
@@ -110,6 +120,9 @@ def compute_grpo_loss(
         avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
         pos_logp = (logp_per_token * pos).mean().item()
         neg_logp = (logp_per_token * neg).mean().item()
+
+        # Collect training logprobs for masked positions (generated tokens only)
+        training_logprobs_flat = logp_per_token[mask.bool()].detach().float().cpu().numpy()
 
     # GRPO loss: weighted log probabilities by advantages
     grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
@@ -124,11 +137,58 @@ def compute_grpo_loss(
         "avg_logp": avg_logp,
         "pos_count": pos.sum().item(),
         "neg_count": neg.sum().item(),
+        "training_logprobs": training_logprobs_flat,  # For alignment check
     }
 
     return grpo_loss, metrics
 
 
+def compute_logprob_alignment(
+    inference_logprobs: List[np.ndarray],
+    training_logprobs: List[np.ndarray],
+) -> Dict[str, float]:
+    """
+    Compute alignment stats between inference and training logprobs.
+
+    At initialization (step 0), these should match closely if the model
+    weights are correctly shared between training and inference.
+
+    Args:
+        inference_logprobs: Logprobs from vLLM inference
+        training_logprobs: Logprobs computed during training forward pass
+
+    Returns:
+        Dict of alignment statistics
+    """
+    if not inference_logprobs or not training_logprobs:
+        return {}
+
+    inf_flat = np.concatenate(inference_logprobs)
+    train_flat = np.concatenate(training_logprobs)
+
+    # Filter out placeholder values (1.0 or 0.0 used for prompt tokens)
+    inf_mask = (inf_flat != 1.0) & (inf_flat != 0.0)
+    train_mask = np.ones_like(train_flat, dtype=bool)  # All training logprobs are valid
+
+    inf_filtered = inf_flat[inf_mask]
+
+    stats = {
+        "logprobs/inference_mean": float(np.mean(inf_filtered)) if len(inf_filtered) > 0 else 0.0,
+        "logprobs/inference_std": float(np.std(inf_filtered)) if len(inf_filtered) > 0 else 0.0,
+        "logprobs/training_mean": float(np.mean(train_flat)) if len(train_flat) > 0 else 0.0,
+        "logprobs/training_std": float(np.std(train_flat)) if len(train_flat) > 0 else 0.0,
+    }
+
+    # Compute diff (key metric for alignment validation)
+    if len(inf_filtered) > 0 and len(train_flat) > 0:
+        stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"]
+
+    # At step 0, this diff should be very close to 0 if weights are shared correctly
+    # A large diff indicates the training model is using different weights than vLLM
+
+    return stats
+
+
 def run_training_step(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
@@ -137,6 +197,7 @@ def run_training_step(
     advantage_batches: List[torch.Tensor],
     temperature_batches: List[torch.Tensor],
     config: TrainingConfig,
+    inference_logprobs: Optional[List[np.ndarray]] = None,
 ) -> dict:
     """
     Run a single training step with gradient accumulation.
@@ -146,6 +207,7 @@ def run_training_step(
     2. Backward pass with gradient accumulation
     3. Gradient clipping
     4. Optimizer step
+    5. (Optional) Logprob alignment check
 
     Args:
         model: The model to train
@@ -155,16 +217,20 @@ def run_training_step(
         advantage_batches: List of advantage tensors
         temperature_batches: List of temperature tensors
         config: Training configuration
+        inference_logprobs: Optional logprobs from inference for alignment check
 
     Returns:
         Dict of training metrics for this step
     """
+    global _logprob_alignment_stats
+
     total_loss = 0.0
     total_pos_logp = 0.0
     total_neg_logp = 0.0
     total_pos = 0.0
     total_neg = 0.0
     grad_norm = 0.0
+    all_training_logprobs: List[np.ndarray] = []
 
     # Accumulate gradients over micro-batches
     for tokens, labels, advantages, temperatures in zip(
@@ -189,6 +255,10 @@ def run_training_step(
         total_neg_logp += metrics["neg_logp"]
         total_pos += metrics["pos_count"]
         total_neg += metrics["neg_count"]
+
+        # Collect training logprobs for alignment check
+        if "training_logprobs" in metrics:
+            all_training_logprobs.append(metrics["training_logprobs"])
 
     # Gradient clipping and optimizer step
     grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
@@ -202,7 +272,7 @@ def run_training_step(
     if total_neg > 0:
         total_neg_logp /= num_batches
 
-    return {
+    result = {
         "loss": total_loss,
         "grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
         "pos_logp": total_pos_logp,
@@ -210,6 +280,14 @@ def run_training_step(
         "pos_count": total_pos,
         "neg_count": total_neg,
     }
+
+    # Compute logprob alignment stats
+    if inference_logprobs is not None and all_training_logprobs:
+        alignment_stats = compute_logprob_alignment(inference_logprobs, all_training_logprobs)
+        _logprob_alignment_stats.update(alignment_stats)
+        result["logprob_alignment"] = alignment_stats
+
+    return result
 
 
 def log_metrics(
@@ -229,6 +307,8 @@ def log_metrics(
         extra_metrics: Optional additional metrics to log
         benchmark: Whether to show timing/benchmark info
     """
+    global _logprob_alignment_stats
+
     # Build timing string (only if benchmark enabled)
     timing_str = ""
     if benchmark:
@@ -259,6 +339,19 @@ def log_metrics(
             f" Advantages: +{int(pos_count)} / -{int(neg_count)}, "
             f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}"
         )
+
+    # Show logprob alignment stats (important for shared_vllm validation!)
+    if "logprob_alignment" in metrics:
+        alignment = metrics["logprob_alignment"]
+        if "logprobs/diff" in alignment:
+            diff = alignment["logprobs/diff"]
+            inf_mean = alignment.get("logprobs/inference_mean", 0)
+            train_mean = alignment.get("logprobs/training_mean", 0)
+
+            # At step 0, diff should be ~0 if weights are shared correctly
+            status = "OK" if abs(diff) < 0.1 else "MISMATCH!"
+            print(f" LogProb Alignment: inf={inf_mean:.4f}, train={train_mean:.4f}, "
+                  f"diff={diff:.4f} [{status}]")
 
     if use_wandb:
         log_dict = {
@@ -272,6 +365,11 @@ def log_metrics(
                     "gpu_memory_gb", "gpu_memory_reserved_gb"]:
             if key in metrics:
                 log_dict[f"train/{key}"] = metrics[key]
+
+        # Add logprob alignment stats (key for shared_vllm validation!)
+        if _logprob_alignment_stats:
+            log_dict.update(_logprob_alignment_stats)
+
         if extra_metrics:
             log_dict.update(extra_metrics)
         wandb.log(log_dict, step=step)