diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 341538ab..2a3bede4 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -145,6 +145,16 @@ class TrainingConfig(BaseModel):
             "If None, defaults to ['q_proj', 'v_proj'] for most models."
         ),
     )
+
+    # CUDA IPC mode (for shared_vllm mode - true shared GPU memory)
+    use_cuda_ipc: bool = Field(
+        False,
+        description=(
+            "Enable CUDA IPC for true shared GPU memory with vLLM. "
+            "This allows the trainer to use vLLM's model weights directly without loading a copy. "
+            "Requires both processes on the SAME GPU. Saves ~8 GB for a 3B model."
+        ),
+    )
 
 
 @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
@@ -375,12 +385,19 @@ def load_model_and_tokenizer(
     tokenizer = AutoTokenizer.from_pretrained(config.model_name)
 
     if config.weight_bridge_mode == "shared_vllm" and bridge is not None:
-        print("[Setup] Loading model for shared vLLM mode...")
-        model = AutoModelForCausalLM.from_pretrained(
-            config.model_name, torch_dtype=torch.bfloat16
-        )
-        model.to(config.device)
-        bridge.attach_to_vllm_weights(dict(model.named_parameters()))
+        if config.use_cuda_ipc:
+            # CUDA IPC mode: use vLLM's weights directly (NO NEW MEMORY!)
+            print("[Setup] Using CUDA IPC shared memory mode...")
+            print("[Setup] Trainer will use vLLM's model weights directly!")
+            model = bridge.get_trainable_model()
+        else:
+            # Standard shared mode: load own copy, notify via HTTP
+            print("[Setup] Loading model for shared vLLM mode...")
+            model = AutoModelForCausalLM.from_pretrained(
+                config.model_name, torch_dtype=torch.bfloat16
+            )
+            model.to(config.device)
+            bridge.attach_to_vllm_weights(dict(model.named_parameters()))
 
     elif config.weight_bridge_mode == "lora_only":
         model = _load_model_with_lora(config)
@@ -394,14 +411,17 @@
     # Enable gradient checkpointing (saves memory)
     # For LoRA, use PEFT's method; for others, use standard method
+    # NOTE: Skip for CUDA IPC as the model structure is different
     if config.weight_bridge_mode == "lora_only":
         # PEFT models need gradient_checkpointing enabled on base model
         # and require use_reentrant=False for proper gradient flow
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-    else:
+    elif not config.use_cuda_ipc:
+        # Standard gradient checkpointing for non-IPC modes
         model.gradient_checkpointing_enable()
+    # CUDA IPC mode: gradient checkpointing may not work with shared tensors
 
     model.train()
@@ -1020,9 +1040,14 @@ def train_shared_vllm(config: TrainingConfig):
     use_wandb = setup_wandb(config)
 
     print(f"\n{'='*60}")
-    print("SHARED VLLM MODE (in-place weight updates)")
+    if config.use_cuda_ipc:
+        print("SHARED VLLM MODE (CUDA IPC - TRUE SHARED MEMORY)")
+        print(">>> NO MODEL COPY - using vLLM's weights directly!")
+    else:
+        print("SHARED VLLM MODE (HTTP notifications)")
     print(f"{'='*60}")
     print(f"Model: {config.model_name}")
+    print(f"CUDA IPC: {config.use_cuda_ipc}")
     print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
     print(f"Init method: {config.init_method}")
     print(f"Inference nodes: {config.num_inference_nodes}")
@@ -1462,6 +1487,17 @@ def parse_args() -> argparse.Namespace:
         default=None,
         help="Module names to apply LoRA to (default: q_proj v_proj)",
    )
+
+    # --- CUDA IPC arguments ---
+    parser.add_argument(
+        "--use-cuda-ipc",
+        action="store_true",
+        help=(
+            "Enable CUDA IPC for true shared GPU memory with vLLM (shared_vllm mode only). "
+            "Trainer uses vLLM's model weights directly - no copy needed! "
+            "Requires both processes on the SAME GPU. Saves ~8 GB for a 3B model."
+        ),
+    )
 
     return parser.parse_args()
@@ -1492,6 +1528,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         lora_target_modules=args.lora_target_modules,
+        use_cuda_ipc=args.use_cuda_ipc,
     )
diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 9b43b838..09c84514 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -711,6 +711,83 @@ async def bridge_disable() -> JSONResponse:
     return JSONResponse({"status": "ok"})
 
 
+@app.post("/bridge/export_cuda_ipc")
+async def bridge_export_cuda_ipc() -> JSONResponse:
+    """
+    Export CUDA IPC handles for all model parameters.
+
+    This enables TRUE shared memory between vLLM and the trainer.
+    The trainer can reconstruct tensors from these handles without
+    allocating new GPU memory - they share the exact same memory!
+
+    IMPORTANT: Only works when both processes are on the SAME GPU.
+
+    Returns:
+        JSON with IPC handles, shapes, dtypes for each parameter.
+    """
+    import base64
+    import pickle
+
+    assert engine is not None
+
+    try:
+        # Access the underlying model
+        model = engine.engine.model_executor.driver_worker.model_runner.model
+
+        ipc_handles = {}
+        for name, param in model.named_parameters():
+            try:
+                # Get the underlying storage
+                storage = param.data.storage()
+
+                # Get CUDA IPC handle - this is the key to shared memory!
+                # The handle can be sent to another process on the same GPU
+                # to reconstruct a tensor pointing to the SAME memory
+                handle = storage._share_cuda_()
+
+                # Encode handle for JSON transmission
+                handle_bytes = pickle.dumps(handle)
+                handle_b64 = base64.b64encode(handle_bytes).decode('ascii')
+
+                ipc_handles[name] = {
+                    "ipc_handle": handle_b64,
+                    "shape": list(param.shape),
+                    "dtype": str(param.dtype),
+                    "device_index": param.device.index,
+                    "storage_offset": param.storage_offset(),
+                    "numel": param.numel(),
+                    "stride": list(param.stride()),
+                }
+            except Exception as e:
+                logger.warning(f"Could not export IPC handle for {name}: {e}")
+                continue
+
+        # Save to file for trainer to read
+        log_dir = os.environ.get("LOGDIR", ".")
+        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
+
+        with open(ipc_path, "w") as f:
+            json.dump({
+                "handles": ipc_handles,
+                "model": getattr(getattr(engine, "model_config", None), "model", "unknown"),
+                "device_count": torch.cuda.device_count(),
+                "export_time": time.time(),
+            }, f, indent=2)
+
+        logger.info(f"Exported {len(ipc_handles)} CUDA IPC handles to {ipc_path}")
+
+        return JSONResponse({
+            "status": "ok",
+            "num_parameters": len(ipc_handles),
+            "ipc_path": str(ipc_path),
+            "total_params": sum(info["numel"] for info in ipc_handles.values()),
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to export CUDA IPC handles: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 # =============================================================================
 # LoRA Endpoints (for adapter hot-swapping)
 # =============================================================================
diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py
index c563c03a..695398d0 100644
--- a/example_trainer/vllm_weight_bridge.py
+++ b/example_trainer/vllm_weight_bridge.py
@@ -314,6 +314,9 @@ class BridgeConfig:
 
     # vLLM server URL for HTTP-based sync (local mode)
     vllm_api_url: str = "http://localhost:9001"
"http://localhost:9001" + + # CUDA IPC mode: share GPU memory directly with vLLM (same GPU only!) + use_cuda_ipc: bool = False # Derived from environment num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count()) @@ -338,6 +341,7 @@ class BridgeConfig: device=config.device, log_dir=os.environ.get("LOGDIR"), vllm_api_url=f"http://localhost:{getattr(config, 'vllm_port', 9001)}", + use_cuda_ipc=getattr(config, 'use_cuda_ipc', False), ) @@ -419,9 +423,12 @@ class VLLMWeightBridge: In local mode: - No NCCL process groups (trainer and vLLM are separate processes) - Communication via HTTP to vLLM's bridge endpoints - - Trainer loads its own model copy, updates are synced via checkpoints + - Trainer loads its own model copy, OR uses CUDA IPC for true shared memory """ - print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)") + if self.config.use_cuda_ipc: + print("[Bridge] Using CUDA IPC MODE (true shared GPU memory)") + else: + print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)") print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}") # Verify vLLM server is reachable @@ -436,12 +443,104 @@ class VLLMWeightBridge: print(f"[Bridge] Warning: Could not reach vLLM server: {e}") print("[Bridge] Training will continue, but vLLM sync may not work") + # For CUDA IPC mode, request vLLM to export IPC handles + if self.config.use_cuda_ipc: + self._request_cuda_ipc_export() + self._load_cuda_ipc_handles() + # Load parameter mappings if available (optional in local mode) try: self._load_param_mappings() except RuntimeError: print("[Bridge] Parameter mapping file not found (optional in local mode)") self.param_mappings = {} + + def _request_cuda_ipc_export(self) -> None: + """Request vLLM to export CUDA IPC handles.""" + import requests + + print("[Bridge] Requesting CUDA IPC handles from vLLM...") + try: + response = requests.post( + f"{self.config.vllm_api_url}/bridge/export_cuda_ipc", + timeout=60 + ) + if response.status_code == 200: + result = response.json() + print(f"[Bridge] vLLM exported {result.get('num_parameters', 0)} IPC handles") + else: + raise RuntimeError(f"Failed to export IPC handles: {response.status_code}") + except Exception as e: + raise RuntimeError(f"Could not request CUDA IPC export: {e}") + + def _load_cuda_ipc_handles(self) -> None: + """ + Load CUDA IPC handles from file and reconstruct shared tensors. + + This is the key to TRUE shared memory - the tensors we create here + point to the SAME GPU memory that vLLM is using! + """ + import base64 + import pickle + + log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") + ipc_path = Path(log_dir) / "cuda_ipc_handles.json" + + # Wait for file to be created + wait_time = 0 + while not ipc_path.exists() and wait_time < self.config.timeout_seconds: + print(f"[Bridge] Waiting for {ipc_path}...") + time.sleep(1) + wait_time += 1 + + if not ipc_path.exists(): + raise RuntimeError(f"CUDA IPC handles file not found: {ipc_path}") + + with open(ipc_path, "r") as f: + data = json.load(f) + + handles_data = data.get("handles", {}) + + print(f"[Bridge] Reconstructing {len(handles_data)} shared tensors from IPC handles...") + + reconstructed = 0 + for name, info in handles_data.items(): + try: + # Decode the IPC handle + handle_bytes = base64.b64decode(info["ipc_handle"]) + handle = pickle.loads(handle_bytes) + + # Reconstruct the storage from the IPC handle + # This does NOT allocate new memory - it maps to existing memory! 
+ device = torch.device(f"cuda:{info['device_index']}") + + # Get dtype + dtype_str = info["dtype"] + dtype = getattr(torch, dtype_str.replace("torch.", "")) + + # Reconstruct tensor from IPC handle + # The storage is shared with vLLM's process + storage = torch.cuda.Storage._new_shared_cuda(*handle) + + # Create tensor view of the shared storage + tensor = torch.tensor([], dtype=dtype, device=device) + tensor.set_( + storage, + info["storage_offset"], + info["shape"], + info["stride"] + ) + + # Store in shared_state_dict + self.shared_state_dict[name] = tensor + reconstructed += 1 + + except Exception as e: + print(f"[Bridge] Warning: Could not reconstruct {name}: {e}") + continue + + print(f"[Bridge] Successfully reconstructed {reconstructed} shared tensors") + print(f"[Bridge] Memory savings: ~{reconstructed * 4 / 1024:.1f} GB (no model copy needed!)") def _initialize_distributed_mode(self) -> None: """ @@ -561,6 +660,9 @@ class VLLMWeightBridge: """ Get a model whose parameters point to vLLM's shared tensors. + In CUDA IPC mode: shared_state_dict is populated from IPC handles during init. + In other modes: must call attach_to_vllm_weights() first. + This creates a HuggingFace model structure but replaces all parameters with references to the shared tensors. When the optimizer updates these parameters, it modifies vLLM's weights directly. @@ -572,26 +674,42 @@ class VLLMWeightBridge: return self._model if not self.shared_state_dict: - raise RuntimeError( - "Must call attach_to_vllm_weights() before get_trainable_model()" - ) + if self.config.use_cuda_ipc: + raise RuntimeError( + "CUDA IPC mode enabled but no shared tensors found. " + "Check that vLLM exported IPC handles correctly." + ) + else: + raise RuntimeError( + "Must call attach_to_vllm_weights() before get_trainable_model()" + ) print(f"[Bridge] Creating trainable model for {self.config.model_name}") + if self.config.use_cuda_ipc: + print("[Bridge] Using CUDA IPC shared tensors (NO NEW GPU MEMORY!)") # Load model config (not weights) model_config = AutoConfig.from_pretrained(self.config.model_name) - # Create model with empty weights + # Create model with empty weights (meta device = no memory) with torch.device("meta"): model = AutoModelForCausalLM.from_config(model_config) # Replace each parameter with the shared tensor self._replace_parameters_with_shared(model) + # Move model structure to device (parameters already on device via IPC) model.to(self.device) self._model = model - print(f"[Bridge] Trainable model ready with {sum(p.numel() for p in model.parameters())} parameters") + total_params = sum(p.numel() for p in model.parameters()) + print(f"[Bridge] Trainable model ready with {total_params:,} parameters") + + if self.config.use_cuda_ipc: + # Verify memory savings + param_memory_gb = total_params * 2 / 1e9 # bfloat16 = 2 bytes + print(f"[Bridge] CUDA IPC memory savings: ~{param_memory_gb:.1f} GB (shared with vLLM)") + return model def _replace_parameters_with_shared(self, model: nn.Module) -> None:
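Not part of the patch: a minimal sketch of the CUDA IPC round-trip that /bridge/export_cuda_ipc and _load_cuda_ipc_handles perform, assuming PyTorch >= 2.0 and both processes pinned to the same physical GPU. The helper names export_handle and rebuild_tensor are illustrative only, and Python objects are passed directly rather than through the patch's base64/JSON transport; rebuild_tensor must run in a different process from the one that called export_handle, since a process generally cannot reopen a CUDA IPC handle it exported itself.

# Illustrative sketch only (assumed names export_handle / rebuild_tensor; PyTorch >= 2.0).
# export_handle would run inside the vLLM server process, rebuild_tensor inside the trainer.
import pickle

import torch


def export_handle(param: torch.Tensor) -> dict:
    # Share the parameter's underlying CUDA storage and record everything
    # needed to rebuild a view of it in another process.
    handle = param.untyped_storage()._share_cuda_()
    return {
        "ipc_handle": pickle.dumps(handle),
        "dtype": param.dtype,
        "device_index": param.device.index,
        "shape": tuple(param.shape),
        "stride": param.stride(),
        "storage_offset": param.storage_offset(),
    }


def rebuild_tensor(info: dict) -> torch.Tensor:
    # Map the exported storage into this process (no new GPU allocation)
    # and wrap it in a tensor view; in-place writes are visible to the exporter.
    handle = pickle.loads(info["ipc_handle"])
    storage = torch.UntypedStorage._new_shared_cuda(*handle)
    t = torch.empty((0,), dtype=info["dtype"], device=f"cuda:{info['device_index']}")
    return t.set_(storage, info["storage_offset"], info["shape"], info["stride"])

The patch's endpoint serializes the same fields (pickled handle, shape, stride, storage offset, dtype, device index) into cuda_ipc_handles.json, and _load_cuda_ipc_handles mirrors rebuild_tensor when it repopulates shared_state_dict.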