diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 1de50c29..115135c0 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -1,35 +1,31 @@
 #!/usr/bin/env python3
 """
-Custom vLLM API server with shared memory weight updates.
+Custom vLLM API server with CUDA IPC shared memory support.
 
 This server extends the standard vLLM API with:
-- Shared-weight training via NCCL (patched GPUModelRunner)
+- Single-copy mode: exports CUDA IPC handles so the trainer can share vLLM's tensors
 - LoRA hot-swap without server restart
-- Weight synchronization endpoints
+- Bridge endpoints for coordination
 
-ARCHITECTURE:
-    When --enable-shared-weights is set:
-    1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
-    2. A daemon process is spawned that receives NCCL weight updates
-    3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
+ARCHITECTURE (Single-Copy Mode):
+    When VLLM_ENABLE_SHARED_WEIGHTS=1 is set:
+    1. vLLM's GPUModelRunner is patched BEFORE the model is loaded
+    2. The patched runner exports CUDA IPC handles to vllm_bridge_config.json
+    3. The trainer reads the IPC handles and attaches to the SAME tensors
+    4. optimizer.step() updates the weights in place - vLLM sees the changes immediately!
 
 ┌─────────────────────────────────────────────────────────────────────────┐
-│                   SHARED MEMORY (via share_memory_())                   │
+│                      SINGLE GPU (True Shared Memory)                    │
 │    ┌───────────────────────────────────────────────────────────────┐    │
-│    │                         Model Weights                         │    │
-│    │             (accessible from MULTIPLE processes)              │    │
+│    │                   Model Weights (ONE copy!)                   │    │
+│    │               (accessible via CUDA IPC handles)               │    │
 │    └───────────────────────────────────────────────────────────────┘    │
 │                 ▲                                  ▲                    │
-│                 │ Reads                            │ Writes             │
+│                 │ Reads (inference)                │ Writes (train)     │
 │        ┌────────┴────────┐             ┌───────────┴───────────┐        │
-│        │   vLLM Worker   │             │    weight_updater     │        │
-│        │   (inference)   │             │    daemon process     │        │
-│        └─────────────────┘             └───────────┬───────────┘        │
-│                                                    │ NCCL               │
-│                                                    ▼                    │
-│                                         ┌─────────────────────┐         │
-│                                         │   Trainer Process   │         │
-│                                         └─────────────────────┘         │
+│        │   vLLM Worker   │             │    Trainer Process    │        │
+│        │                 │             │  (attached via IPC)   │        │
+│        └─────────────────┘             └───────────────────────┘        │
 └─────────────────────────────────────────────────────────────────────────┘
 
 CRITICAL: Patches must be applied BEFORE importing vLLM!
@@ -44,7 +40,6 @@ import os
 import ssl
 import sys
 import threading
-import time
 from argparse import Namespace
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass, field
@@ -145,11 +140,9 @@ engine: Optional[AsyncLLM] = None
 
 @dataclass
 class BridgeState:
-    """State for weight bridge synchronization."""
-    enabled: bool = False
+    """State for shared memory and LoRA."""
     update_count: int = 0
     last_update_time: float = 0.0
-    rendezvous_info: Dict[str, Any] = field(default_factory=dict)
     lock: threading.Lock = field(default_factory=threading.Lock)
 
     # LoRA state
@@ -169,24 +162,10 @@ class BridgeInfoResponse(BaseModel):
     enabled: bool
     update_count: int
     last_update_time: float
-    rendezvous_info: Dict[str, Any]
     model_name: str
     device: str
 
 
-class BridgeInitRequest(BaseModel):
-    master_addr: str
-    master_port: int
-    world_size: int
-    trainer_ranks: List[int]
-
-
-class WeightUpdateNotification(BaseModel):
-    update_count: int
-    trainer_rank: int
-    timestamp: float
-
-
 class LoraLoadRequest(BaseModel):
     adapter_path: str
     adapter_name: Optional[str] = None
@@ -318,61 +297,15 @@ async def bridge_info() -> JSONResponse:
     model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
 
     return JSONResponse({
-        "enabled": bridge_state.enabled or PATCHES_APPLIED,
+        "enabled": PATCHES_APPLIED,
         "shared_weights": PATCHES_APPLIED,
         "update_count": bridge_state.update_count,
         "last_update_time": bridge_state.last_update_time,
-        "rendezvous_info": bridge_state.rendezvous_info,
         "model_name": model_name,
         "device": "cuda" if torch.cuda.is_available() else "cpu",
     })
 
 
-@app.post("/bridge/init")
-async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
-    """Initialize the weight bridge."""
-    with bridge_state.lock:
-        bridge_state.enabled = True
-        bridge_state.rendezvous_info = {
-            "master_addr": request.master_addr,
-            "master_port": request.master_port,
-            "world_size": request.world_size,
-            "trainer_ranks": request.trainer_ranks,
-        }
-
-    logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
-
-    return JSONResponse({
-        "status": "ok",
-        "message": "Weight bridge initialized",
-        "shared_weights_enabled": PATCHES_APPLIED,
-    })
-
-
-@app.post("/bridge/notify_update")
-async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
-    """
-    Notification that trainer has updated weights.
-
-    In shared memory mode (PATCHES_APPLIED=True), updates are automatic
-    via the NCCL daemon. This endpoint is for logging/coordination.
- """ - with bridge_state.lock: - bridge_state.update_count = notification.update_count - bridge_state.last_update_time = notification.timestamp - - if PATCHES_APPLIED: - logger.debug(f"Weight update #{notification.update_count} (shared memory)") - else: - logger.info(f"Weight update #{notification.update_count} (HTTP notification only)") - - return JSONResponse({ - "status": "ok", - "update_count": bridge_state.update_count, - "shared_weights": PATCHES_APPLIED, - }) - - @app.get("/bridge/state_dict_info") async def bridge_state_dict_info() -> JSONResponse: """Get model parameter information.""" @@ -391,17 +324,8 @@ async def bridge_state_dict_info() -> JSONResponse: return JSONResponse({"error": str(e)}) -@app.post("/bridge/disable") -async def bridge_disable() -> JSONResponse: - """Disable the weight bridge.""" - with bridge_state.lock: - bridge_state.enabled = False - logger.info("Bridge disabled") - return JSONResponse({"status": "ok"}) - - # ============================================================================= -# Pause/Resume Endpoints (for weight updates) +# Pause/Resume Endpoints # =============================================================================