diff --git a/example_trainer/checkpointing.py b/example_trainer/checkpointing.py index dec746a5..bf70fbdf 100644 --- a/example_trainer/checkpointing.py +++ b/example_trainer/checkpointing.py @@ -4,20 +4,56 @@ Checkpoint saving utilities for GRPO trainer. Handles saving model checkpoints for different training modes: - Full model checkpoints (legacy and shared_vllm modes) - LoRA adapter checkpoints + +IMPORTANT: For shared_vllm mode, the model parameters are VIEWS into vLLM's +fused tensors (qkv_proj, gate_up_proj). This module handles unfusing them +back to HuggingFace format for safe checkpoint saving. """ import os import shutil +from typing import Dict import torch +def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + """ + Create a state dict with contiguous tensors for safe saving. + + This is critical for shared_vllm mode where parameters are views into + vLLM's fused tensors. Views may share storage and not be contiguous, + which can cause issues when saving. + + Returns: + State dict with all tensors made contiguous (copied if necessary) + """ + state_dict = {} + for name, param in model.named_parameters(): + # Check if tensor is a view (non-contiguous or shares storage) + if not param.is_contiguous() or param.storage_offset() != 0: + # Make a contiguous copy - this "unfuses" the view + state_dict[name] = param.detach().clone().contiguous() + else: + state_dict[name] = param.detach() + + # Also include buffers + for name, buffer in model.named_buffers(): + if not buffer.is_contiguous() or buffer.storage_offset() != 0: + state_dict[name] = buffer.detach().clone().contiguous() + else: + state_dict[name] = buffer.detach() + + return state_dict + + def save_checkpoint( model: torch.nn.Module, tokenizer, save_path: str, step: int, is_final: bool = False, + safe_mode: bool = True, ) -> str: """ Save full model checkpoint. @@ -28,6 +64,9 @@ def save_checkpoint( save_path: Base directory for checkpoints step: Current training step is_final: Whether this is the final checkpoint + safe_mode: If True, ensure all tensors are contiguous before saving. + This is important for shared_vllm mode where params are + views into fused vLLM tensors. Returns: Path where checkpoint was saved @@ -43,7 +82,26 @@ def save_checkpoint( shutil.rmtree(checkpoint_path) os.makedirs(checkpoint_path, exist_ok=True) - model.save_pretrained(checkpoint_path) + if safe_mode: + # For shared_vllm mode: ensure views are properly unfused + print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...") + state_dict = _ensure_contiguous_state_dict(model) + + # Count how many were non-contiguous (views into fused tensors) + view_count = sum( + 1 for name, param in model.named_parameters() + if not param.is_contiguous() or param.storage_offset() != 0 + ) + if view_count > 0: + print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)") + + # Save state dict manually, then save config separately + torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin")) + model.config.save_pretrained(checkpoint_path) + else: + # Standard save (may have issues with view tensors) + model.save_pretrained(checkpoint_path) + tokenizer.save_pretrained(checkpoint_path) print(" Checkpoint saved.") diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index 8cfc9251..c07f322f 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -90,7 +90,8 @@ def train_legacy(config: TrainingConfig): # Fetch data data_fetch_start = time.time() if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) + batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, + extract_inference_logprobs=False) token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -206,26 +207,34 @@ def train_shared_vllm(config: TrainingConfig): # === Training Loop === batches = [] + inference_logprobs = None for step in range(config.training_steps): print(f"\nStep {step+1}/{config.training_steps}") - # Fetch data + # Fetch data (with inference logprobs for alignment check) data_fetch_start = time.time() if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) + batches, inference_logprobs = get_data( + config.batch_size, config.seq_len, config.atropos_url, + extract_inference_logprobs=True, # Enable logprob alignment check + ) token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) - # Training step + # Training step (with logprob alignment check) step_start = time.time() metrics = run_training_step( model, optimizer, token_batches, label_batches, advantage_batches, temperature_batches, config, + inference_logprobs=inference_logprobs, # Pass for alignment validation ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) + + # Clear inference logprobs after use (will be refreshed with new data) + inference_logprobs = None # GPU memory tracking gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 @@ -335,7 +344,8 @@ def train_lora(config: TrainingConfig): # Fetch data data_fetch_start = time.time() if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) + batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, + extract_inference_logprobs=False) token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time)