Bloat reduction

Jai Suphavadeeprasit 2026-01-17 12:05:03 -05:00
parent 01af0777bc
commit ab8d2f2dac


@@ -1,35 +1,31 @@
#!/usr/bin/env python3
"""
Custom vLLM API server with shared memory weight updates.
Custom vLLM API server with CUDA IPC shared memory support.
This server extends the standard vLLM API with:
- Shared-weight training via NCCL (patched GPUModelRunner)
- Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors
- LoRA hot-swap without server restart
- Weight synchronization endpoints
- Bridge endpoints for coordination
ARCHITECTURE:
When --enable-shared-weights is set:
1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
2. A daemon process is spawned that receives NCCL weight updates
3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
ARCHITECTURE (Single-Copy Mode):
When VLLM_ENABLE_SHARED_WEIGHTS=1:
1. vLLM's GPUModelRunner is patched BEFORE loading
2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json
3. Trainer reads IPC handles and attaches to the SAME tensors
4. optimizer.step() updates weights in-place - vLLM sees changes immediately!
[Old ASCII diagram: "SHARED MEMORY (via share_memory_())". One copy of the model weights, accessible from MULTIPLE processes; the vLLM worker reads (inference) while a weight_updater daemon process writes, receiving updates over NCCL from the trainer process.]
[New ASCII diagram: "SINGLE GPU (True Shared Memory)". Model weights (ONE copy!), accessible via CUDA IPC handles; the vLLM worker reads (inference) while the trainer process, attached via IPC, writes (train).]
CRITICAL: Patches must be applied BEFORE importing vLLM!
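Illustrative sketch (not part of this commit): the single-copy flow above relies on PyTorch's CUDA IPC support, where one process exports a handle for a GPU tensor and another process attaches to the same memory. The sketch below demonstrates that mechanism with torch.multiprocessing.reductions.reduce_tensor; the producer/consumer split, queue plumbing, and tensor shape are assumptions for illustration, not the server's actual patch code (which, per the docstring, writes the handles to vllm_bridge_config.json for the trainer to read).

# Illustrative sketch only: two processes operating on ONE copy of a CUDA tensor
# via an IPC handle, the mechanism the single-copy mode above is built on.
import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor

def trainer(handle_queue):
    # "Trainer" side: rebuild the tensor from the exported IPC handle and
    # update it in place; no copy of the weights is made.
    rebuild_fn, rebuild_args = handle_queue.get()
    weights = rebuild_fn(*rebuild_args)       # attaches to the SAME GPU memory
    with torch.no_grad():
        weights.add_(1.0)                     # stands in for optimizer.step()
    torch.cuda.synchronize()

if __name__ == "__main__":
    mp.set_start_method("spawn")              # required for CUDA + multiprocessing
    weights = torch.zeros(4, device="cuda")   # "vLLM" side owns the weights
    handle_queue = mp.SimpleQueue()
    handle_queue.put(reduce_tensor(weights))  # picklable CUDA IPC handle info
    proc = mp.Process(target=trainer, args=(handle_queue,))
    proc.start()
    proc.join()
    torch.cuda.synchronize()
    print(weights)                            # reflects the trainer's in-place write

As the CRITICAL note above stresses, the real patching has to run before any vLLM import; this sketch sidesteps that by not touching vLLM at all.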
@@ -44,7 +40,6 @@ import os
import ssl
import sys
import threading
import time
from argparse import Namespace
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
@@ -145,11 +140,9 @@ engine: Optional[AsyncLLM] = None
@dataclass
class BridgeState:
"""State for weight bridge synchronization."""
enabled: bool = False
"""State for shared memory and LoRA."""
update_count: int = 0
last_update_time: float = 0.0
rendezvous_info: Dict[str, Any] = field(default_factory=dict)
lock: threading.Lock = field(default_factory=threading.Lock)
# LoRA state
@@ -169,24 +162,10 @@ class BridgeInfoResponse(BaseModel):
enabled: bool
update_count: int
last_update_time: float
rendezvous_info: Dict[str, Any]
model_name: str
device: str
class BridgeInitRequest(BaseModel):
master_addr: str
master_port: int
world_size: int
trainer_ranks: List[int]
class WeightUpdateNotification(BaseModel):
update_count: int
trainer_rank: int
timestamp: float
class LoraLoadRequest(BaseModel):
adapter_path: str
adapter_name: Optional[str] = None
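Illustrative usage sketch (not part of this commit): LoraLoadRequest above only fixes the JSON payload shape for LoRA hot-swap; the route it is posted to is not visible in this hunk, so the path and port below are assumptions.

# Hedged example of requesting a LoRA hot-swap; only the field names
# (adapter_path, adapter_name) come from LoraLoadRequest above. The
# endpoint path and port are placeholders, not taken from this file.
import requests

resp = requests.post(
    "http://localhost:8000/lora/load",       # hypothetical route and port
    json={
        "adapter_path": "/path/to/adapter",  # required
        "adapter_name": "my-adapter",        # optional, defaults to None
    },
    timeout=30,
)
print(resp.status_code, resp.json())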
@@ -318,61 +297,15 @@ async def bridge_info() -> JSONResponse:
model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
return JSONResponse({
"enabled": bridge_state.enabled or PATCHES_APPLIED,
"enabled": PATCHES_APPLIED,
"shared_weights": PATCHES_APPLIED,
"update_count": bridge_state.update_count,
"last_update_time": bridge_state.last_update_time,
"rendezvous_info": bridge_state.rendezvous_info,
"model_name": model_name,
"device": "cuda" if torch.cuda.is_available() else "cpu",
})
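Illustrative usage sketch (not part of this commit): a trainer or supervisor can poll this endpoint to confirm the shared-weight patches are active before attaching. Host and port below are assumptions; the response fields come from the handler above.

# Hedged usage sketch for GET /bridge/info; host and port are assumptions.
import requests

info = requests.get("http://localhost:8000/bridge/info", timeout=10).json()
if not info["shared_weights"]:
    raise RuntimeError("server was not started with the shared-weight patches")
print(info["model_name"], info["device"], info["update_count"])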
@app.post("/bridge/init")
async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
"""Initialize the weight bridge."""
with bridge_state.lock:
bridge_state.enabled = True
bridge_state.rendezvous_info = {
"master_addr": request.master_addr,
"master_port": request.master_port,
"world_size": request.world_size,
"trainer_ranks": request.trainer_ranks,
}
logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
return JSONResponse({
"status": "ok",
"message": "Weight bridge initialized",
"shared_weights_enabled": PATCHES_APPLIED,
})
@app.post("/bridge/notify_update")
async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
"""
Notification that trainer has updated weights.
In shared memory mode (PATCHES_APPLIED=True), updates are automatic
via the NCCL daemon. This endpoint is for logging/coordination.
"""
with bridge_state.lock:
bridge_state.update_count = notification.update_count
bridge_state.last_update_time = notification.timestamp
if PATCHES_APPLIED:
logger.debug(f"Weight update #{notification.update_count} (shared memory)")
else:
logger.info(f"Weight update #{notification.update_count} (HTTP notification only)")
return JSONResponse({
"status": "ok",
"update_count": bridge_state.update_count,
"shared_weights": PATCHES_APPLIED,
})
@app.get("/bridge/state_dict_info")
async def bridge_state_dict_info() -> JSONResponse:
"""Get model parameter information."""
@@ -391,17 +324,8 @@ async def bridge_state_dict_info() -> JSONResponse:
return JSONResponse({"error": str(e)})
@app.post("/bridge/disable")
async def bridge_disable() -> JSONResponse:
"""Disable the weight bridge."""
with bridge_state.lock:
bridge_state.enabled = False
logger.info("Bridge disabled")
return JSONResponse({"status": "ok"})
# =============================================================================
# Pause/Resume Endpoints (for weight updates)
# Pause/Resume Endpoints
# =============================================================================