Bloat reduction

This commit is contained in:
Jai Suphavadeeprasit 2026-01-17 12:05:03 -05:00
parent 01af0777bc
commit ab8d2f2dac

View file

@ -1,35 +1,31 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Custom vLLM API server with shared memory weight updates. Custom vLLM API server with CUDA IPC shared memory support.
This server extends the standard vLLM API with: This server extends the standard vLLM API with:
- Shared-weight training via NCCL (patched GPUModelRunner) - Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors
- LoRA hot-swap without server restart - LoRA hot-swap without server restart
- Weight synchronization endpoints - Bridge endpoints for coordination
ARCHITECTURE: ARCHITECTURE (Single-Copy Mode):
When --enable-shared-weights is set: When VLLM_ENABLE_SHARED_WEIGHTS=1:
1. vLLM's GPUModelRunner is patched to call share_memory_() on weights 1. vLLM's GPUModelRunner is patched BEFORE loading
2. A daemon process is spawned that receives NCCL weight updates 2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json
3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately 3. Trainer reads IPC handles and attaches to the SAME tensors
4. optimizer.step() updates weights in-place - vLLM sees changes immediately!
SHARED MEMORY (via share_memory_()) SINGLE GPU (True Shared Memory)
Model Weights Model Weights (ONE copy!)
(accessible from MULTIPLE processes) (accessible via CUDA IPC handles)
Reads Writes Reads (inference) Writes (train)
vLLM Worker weight_updater vLLM Worker Trainer Process
(inference) daemon process (attached via IPC)
NCCL
Trainer Process
CRITICAL: Patches must be applied BEFORE importing vLLM! CRITICAL: Patches must be applied BEFORE importing vLLM!
@ -44,7 +40,6 @@ import os
import ssl import ssl
import sys import sys
import threading import threading
import time
from argparse import Namespace from argparse import Namespace
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -145,11 +140,9 @@ engine: Optional[AsyncLLM] = None
@dataclass @dataclass
class BridgeState: class BridgeState:
"""State for weight bridge synchronization.""" """State for shared memory and LoRA."""
enabled: bool = False
update_count: int = 0 update_count: int = 0
last_update_time: float = 0.0 last_update_time: float = 0.0
rendezvous_info: Dict[str, Any] = field(default_factory=dict)
lock: threading.Lock = field(default_factory=threading.Lock) lock: threading.Lock = field(default_factory=threading.Lock)
# LoRA state # LoRA state
@ -169,24 +162,10 @@ class BridgeInfoResponse(BaseModel):
enabled: bool enabled: bool
update_count: int update_count: int
last_update_time: float last_update_time: float
rendezvous_info: Dict[str, Any]
model_name: str model_name: str
device: str device: str
class BridgeInitRequest(BaseModel):
master_addr: str
master_port: int
world_size: int
trainer_ranks: List[int]
class WeightUpdateNotification(BaseModel):
update_count: int
trainer_rank: int
timestamp: float
class LoraLoadRequest(BaseModel): class LoraLoadRequest(BaseModel):
adapter_path: str adapter_path: str
adapter_name: Optional[str] = None adapter_name: Optional[str] = None
@ -318,61 +297,15 @@ async def bridge_info() -> JSONResponse:
model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
return JSONResponse({ return JSONResponse({
"enabled": bridge_state.enabled or PATCHES_APPLIED, "enabled": PATCHES_APPLIED,
"shared_weights": PATCHES_APPLIED, "shared_weights": PATCHES_APPLIED,
"update_count": bridge_state.update_count, "update_count": bridge_state.update_count,
"last_update_time": bridge_state.last_update_time, "last_update_time": bridge_state.last_update_time,
"rendezvous_info": bridge_state.rendezvous_info,
"model_name": model_name, "model_name": model_name,
"device": "cuda" if torch.cuda.is_available() else "cpu", "device": "cuda" if torch.cuda.is_available() else "cpu",
}) })
@app.post("/bridge/init")
async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
"""Initialize the weight bridge."""
with bridge_state.lock:
bridge_state.enabled = True
bridge_state.rendezvous_info = {
"master_addr": request.master_addr,
"master_port": request.master_port,
"world_size": request.world_size,
"trainer_ranks": request.trainer_ranks,
}
logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
return JSONResponse({
"status": "ok",
"message": "Weight bridge initialized",
"shared_weights_enabled": PATCHES_APPLIED,
})
@app.post("/bridge/notify_update")
async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
"""
Notification that trainer has updated weights.
In shared memory mode (PATCHES_APPLIED=True), updates are automatic
via the NCCL daemon. This endpoint is for logging/coordination.
"""
with bridge_state.lock:
bridge_state.update_count = notification.update_count
bridge_state.last_update_time = notification.timestamp
if PATCHES_APPLIED:
logger.debug(f"Weight update #{notification.update_count} (shared memory)")
else:
logger.info(f"Weight update #{notification.update_count} (HTTP notification only)")
return JSONResponse({
"status": "ok",
"update_count": bridge_state.update_count,
"shared_weights": PATCHES_APPLIED,
})
@app.get("/bridge/state_dict_info") @app.get("/bridge/state_dict_info")
async def bridge_state_dict_info() -> JSONResponse: async def bridge_state_dict_info() -> JSONResponse:
"""Get model parameter information.""" """Get model parameter information."""
@ -391,17 +324,8 @@ async def bridge_state_dict_info() -> JSONResponse:
return JSONResponse({"error": str(e)}) return JSONResponse({"error": str(e)})
@app.post("/bridge/disable")
async def bridge_disable() -> JSONResponse:
"""Disable the weight bridge."""
with bridge_state.lock:
bridge_state.enabled = False
logger.info("Bridge disabled")
return JSONResponse({"status": "ok"})
# ============================================================================= # =============================================================================
# Pause/Resume Endpoints (for weight updates) # Pause/Resume Endpoints
# ============================================================================= # =============================================================================