mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Bloat reduction
This commit is contained in:
parent
01af0777bc
commit
ab8d2f2dac
1 changed file with 19 additions and 95 deletions
|
|
@@ -1,35 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Custom vLLM API server with shared memory weight updates.
|
||||
Custom vLLM API server with CUDA IPC shared memory support.
|
||||
|
||||
This server extends the standard vLLM API with:
|
||||
- Shared-weight training via NCCL (patched GPUModelRunner)
|
||||
- Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors
|
||||
- LoRA hot-swap without server restart
|
||||
- Weight synchronization endpoints
|
||||
- Bridge endpoints for coordination
|
||||
|
||||
ARCHITECTURE:
|
||||
When --enable-shared-weights is set:
|
||||
1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
|
||||
2. A daemon process is spawned that receives NCCL weight updates
|
||||
3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
|
||||
ARCHITECTURE (Single-Copy Mode):
|
||||
When VLLM_ENABLE_SHARED_WEIGHTS=1:
|
||||
1. vLLM's GPUModelRunner is patched BEFORE loading
|
||||
2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json
|
||||
3. Trainer reads IPC handles and attaches to the SAME tensors
|
||||
4. optimizer.step() updates weights in-place - vLLM sees changes immediately!
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ SHARED MEMORY (via share_memory_()) │
|
||||
│ SINGLE GPU (True Shared Memory) │
|
||||
│ ┌─────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Model Weights │ │
|
||||
│ │ (accessible from MULTIPLE processes) │ │
|
||||
│ │ Model Weights (ONE copy!) │ │
|
||||
│ │ (accessible via CUDA IPC handles) │ │
|
||||
│ └─────────────────────────────────────────────────────────────────┘ │
|
||||
│ ▲ ▲ │
|
||||
│ │ Reads │ Writes │
|
||||
│ │ Reads (inference) │ Writes (train) │
|
||||
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
|
||||
│ │ vLLM Worker │ │ weight_updater │ │
|
||||
│ │ (inference) │ │ daemon process │ │
|
||||
│ └─────────────────┘ └───────────┬───────────┘ │
|
||||
│ │ NCCL │
|
||||
│ ▼ │
|
||||
│ ┌─────────────────────┐ │
|
||||
│ │ Trainer Process │ │
|
||||
│ └─────────────────────┘ │
|
||||
│ │ vLLM Worker │ │ Trainer Process │ │
|
||||
│ │ │ │ (attached via IPC) │ │
|
||||
│ └─────────────────┘ └───────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
CRITICAL: Patches must be applied BEFORE importing vLLM!
|
||||
|
|
@@ -44,7 +40,6 @@ import os
|
|||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@@ -145,11 +140,9 @@ engine: Optional[AsyncLLM] = None
|
|||
|
||||
@dataclass
|
||||
class BridgeState:
|
||||
"""State for weight bridge synchronization."""
|
||||
enabled: bool = False
|
||||
"""State for shared memory and LoRA."""
|
||||
update_count: int = 0
|
||||
last_update_time: float = 0.0
|
||||
rendezvous_info: Dict[str, Any] = field(default_factory=dict)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
# LoRA state
|
||||
|
|
@@ -169,24 +162,10 @@ class BridgeInfoResponse(BaseModel):
|
|||
enabled: bool
|
||||
update_count: int
|
||||
last_update_time: float
|
||||
rendezvous_info: Dict[str, Any]
|
||||
model_name: str
|
||||
device: str
|
||||
|
||||
|
||||
class BridgeInitRequest(BaseModel):
|
||||
master_addr: str
|
||||
master_port: int
|
||||
world_size: int
|
||||
trainer_ranks: List[int]
|
||||
|
||||
|
||||
class WeightUpdateNotification(BaseModel):
|
||||
update_count: int
|
||||
trainer_rank: int
|
||||
timestamp: float
|
||||
|
||||
|
||||
class LoraLoadRequest(BaseModel):
|
||||
adapter_path: str
|
||||
adapter_name: Optional[str] = None
|
||||
|
|
@@ -318,61 +297,15 @@ async def bridge_info() -> JSONResponse:
|
|||
model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
|
||||
|
||||
return JSONResponse({
|
||||
"enabled": bridge_state.enabled or PATCHES_APPLIED,
|
||||
"enabled": PATCHES_APPLIED,
|
||||
"shared_weights": PATCHES_APPLIED,
|
||||
"update_count": bridge_state.update_count,
|
||||
"last_update_time": bridge_state.last_update_time,
|
||||
"rendezvous_info": bridge_state.rendezvous_info,
|
||||
"model_name": model_name,
|
||||
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
||||
})
|
||||
|
||||
|
||||
@app.post("/bridge/init")
|
||||
async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
|
||||
"""Initialize the weight bridge."""
|
||||
with bridge_state.lock:
|
||||
bridge_state.enabled = True
|
||||
bridge_state.rendezvous_info = {
|
||||
"master_addr": request.master_addr,
|
||||
"master_port": request.master_port,
|
||||
"world_size": request.world_size,
|
||||
"trainer_ranks": request.trainer_ranks,
|
||||
}
|
||||
|
||||
logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
|
||||
|
||||
return JSONResponse({
|
||||
"status": "ok",
|
||||
"message": "Weight bridge initialized",
|
||||
"shared_weights_enabled": PATCHES_APPLIED,
|
||||
})
|
||||
|
||||
|
||||
@app.post("/bridge/notify_update")
|
||||
async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
|
||||
"""
|
||||
Notification that trainer has updated weights.
|
||||
|
||||
In shared memory mode (PATCHES_APPLIED=True), updates are automatic
|
||||
via the NCCL daemon. This endpoint is for logging/coordination.
|
||||
"""
|
||||
with bridge_state.lock:
|
||||
bridge_state.update_count = notification.update_count
|
||||
bridge_state.last_update_time = notification.timestamp
|
||||
|
||||
if PATCHES_APPLIED:
|
||||
logger.debug(f"Weight update #{notification.update_count} (shared memory)")
|
||||
else:
|
||||
logger.info(f"Weight update #{notification.update_count} (HTTP notification only)")
|
||||
|
||||
return JSONResponse({
|
||||
"status": "ok",
|
||||
"update_count": bridge_state.update_count,
|
||||
"shared_weights": PATCHES_APPLIED,
|
||||
})
|
||||
|
||||
|
||||
@app.get("/bridge/state_dict_info")
|
||||
async def bridge_state_dict_info() -> JSONResponse:
|
||||
"""Get model parameter information."""
|
||||
|
|
@@ -391,17 +324,8 @@ async def bridge_state_dict_info() -> JSONResponse:
|
|||
return JSONResponse({"error": str(e)})
|
||||
|
||||
|
||||
@app.post("/bridge/disable")
|
||||
async def bridge_disable() -> JSONResponse:
|
||||
"""Disable the weight bridge."""
|
||||
with bridge_state.lock:
|
||||
bridge_state.enabled = False
|
||||
logger.info("Bridge disabled")
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pause/Resume Endpoints (for weight updates)
|
||||
# Pause/Resume Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue