mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
Bloat reduction
This commit is contained in:
parent
01af0777bc
commit
ab8d2f2dac
1 changed file with 19 additions and 95 deletions
|
|
@ -1,35 +1,31 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Custom vLLM API server with shared memory weight updates.
|
Custom vLLM API server with CUDA IPC shared memory support.
|
||||||
|
|
||||||
This server extends the standard vLLM API with:
|
This server extends the standard vLLM API with:
|
||||||
- Shared-weight training via NCCL (patched GPUModelRunner)
|
- Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors
|
||||||
- LoRA hot-swap without server restart
|
- LoRA hot-swap without server restart
|
||||||
- Weight synchronization endpoints
|
- Bridge endpoints for coordination
|
||||||
|
|
||||||
ARCHITECTURE:
|
ARCHITECTURE (Single-Copy Mode):
|
||||||
When --enable-shared-weights is set:
|
When VLLM_ENABLE_SHARED_WEIGHTS=1:
|
||||||
1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
|
1. vLLM's GPUModelRunner is patched BEFORE loading
|
||||||
2. A daemon process is spawned that receives NCCL weight updates
|
2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json
|
||||||
3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
|
3. Trainer reads IPC handles and attaches to the SAME tensors
|
||||||
|
4. optimizer.step() updates weights in-place - vLLM sees changes immediately!
|
||||||
|
|
||||||
┌─────────────────────────────────────────────────────────────────────────┐
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
│ SHARED MEMORY (via share_memory_()) │
|
│ SINGLE GPU (True Shared Memory) │
|
||||||
│ ┌─────────────────────────────────────────────────────────────────┐ │
|
│ ┌─────────────────────────────────────────────────────────────────┐ │
|
||||||
│ │ Model Weights │ │
|
│ │ Model Weights (ONE copy!) │ │
|
||||||
│ │ (accessible from MULTIPLE processes) │ │
|
│ │ (accessible via CUDA IPC handles) │ │
|
||||||
│ └─────────────────────────────────────────────────────────────────┘ │
|
│ └─────────────────────────────────────────────────────────────────┘ │
|
||||||
│ ▲ ▲ │
|
│ ▲ ▲ │
|
||||||
│ │ Reads │ Writes │
|
│ │ Reads (inference) │ Writes (train) │
|
||||||
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
|
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
|
||||||
│ │ vLLM Worker │ │ weight_updater │ │
|
│ │ vLLM Worker │ │ Trainer Process │ │
|
||||||
│ │ (inference) │ │ daemon process │ │
|
│ │ │ │ (attached via IPC) │ │
|
||||||
│ └─────────────────┘ └───────────┬───────────┘ │
|
│ └─────────────────┘ └───────────────────────┘ │
|
||||||
│ │ NCCL │
|
|
||||||
│ ▼ │
|
|
||||||
│ ┌─────────────────────┐ │
|
|
||||||
│ │ Trainer Process │ │
|
|
||||||
│ └─────────────────────┘ │
|
|
||||||
└─────────────────────────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
CRITICAL: Patches must be applied BEFORE importing vLLM!
|
CRITICAL: Patches must be applied BEFORE importing vLLM!
|
||||||
|
|
@ -44,7 +40,6 @@ import os
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -145,11 +140,9 @@ engine: Optional[AsyncLLM] = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BridgeState:
|
class BridgeState:
|
||||||
"""State for weight bridge synchronization."""
|
"""State for shared memory and LoRA."""
|
||||||
enabled: bool = False
|
|
||||||
update_count: int = 0
|
update_count: int = 0
|
||||||
last_update_time: float = 0.0
|
last_update_time: float = 0.0
|
||||||
rendezvous_info: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||||
|
|
||||||
# LoRA state
|
# LoRA state
|
||||||
|
|
@ -169,24 +162,10 @@ class BridgeInfoResponse(BaseModel):
|
||||||
enabled: bool
|
enabled: bool
|
||||||
update_count: int
|
update_count: int
|
||||||
last_update_time: float
|
last_update_time: float
|
||||||
rendezvous_info: Dict[str, Any]
|
|
||||||
model_name: str
|
model_name: str
|
||||||
device: str
|
device: str
|
||||||
|
|
||||||
|
|
||||||
class BridgeInitRequest(BaseModel):
|
|
||||||
master_addr: str
|
|
||||||
master_port: int
|
|
||||||
world_size: int
|
|
||||||
trainer_ranks: List[int]
|
|
||||||
|
|
||||||
|
|
||||||
class WeightUpdateNotification(BaseModel):
|
|
||||||
update_count: int
|
|
||||||
trainer_rank: int
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLoadRequest(BaseModel):
|
class LoraLoadRequest(BaseModel):
|
||||||
adapter_path: str
|
adapter_path: str
|
||||||
adapter_name: Optional[str] = None
|
adapter_name: Optional[str] = None
|
||||||
|
|
@ -318,61 +297,15 @@ async def bridge_info() -> JSONResponse:
|
||||||
model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
|
model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
|
||||||
|
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"enabled": bridge_state.enabled or PATCHES_APPLIED,
|
"enabled": PATCHES_APPLIED,
|
||||||
"shared_weights": PATCHES_APPLIED,
|
"shared_weights": PATCHES_APPLIED,
|
||||||
"update_count": bridge_state.update_count,
|
"update_count": bridge_state.update_count,
|
||||||
"last_update_time": bridge_state.last_update_time,
|
"last_update_time": bridge_state.last_update_time,
|
||||||
"rendezvous_info": bridge_state.rendezvous_info,
|
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bridge/init")
|
|
||||||
async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
|
|
||||||
"""Initialize the weight bridge."""
|
|
||||||
with bridge_state.lock:
|
|
||||||
bridge_state.enabled = True
|
|
||||||
bridge_state.rendezvous_info = {
|
|
||||||
"master_addr": request.master_addr,
|
|
||||||
"master_port": request.master_port,
|
|
||||||
"world_size": request.world_size,
|
|
||||||
"trainer_ranks": request.trainer_ranks,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
|
|
||||||
|
|
||||||
return JSONResponse({
|
|
||||||
"status": "ok",
|
|
||||||
"message": "Weight bridge initialized",
|
|
||||||
"shared_weights_enabled": PATCHES_APPLIED,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bridge/notify_update")
|
|
||||||
async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
|
|
||||||
"""
|
|
||||||
Notification that trainer has updated weights.
|
|
||||||
|
|
||||||
In shared memory mode (PATCHES_APPLIED=True), updates are automatic
|
|
||||||
via the NCCL daemon. This endpoint is for logging/coordination.
|
|
||||||
"""
|
|
||||||
with bridge_state.lock:
|
|
||||||
bridge_state.update_count = notification.update_count
|
|
||||||
bridge_state.last_update_time = notification.timestamp
|
|
||||||
|
|
||||||
if PATCHES_APPLIED:
|
|
||||||
logger.debug(f"Weight update #{notification.update_count} (shared memory)")
|
|
||||||
else:
|
|
||||||
logger.info(f"Weight update #{notification.update_count} (HTTP notification only)")
|
|
||||||
|
|
||||||
return JSONResponse({
|
|
||||||
"status": "ok",
|
|
||||||
"update_count": bridge_state.update_count,
|
|
||||||
"shared_weights": PATCHES_APPLIED,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/bridge/state_dict_info")
|
@app.get("/bridge/state_dict_info")
|
||||||
async def bridge_state_dict_info() -> JSONResponse:
|
async def bridge_state_dict_info() -> JSONResponse:
|
||||||
"""Get model parameter information."""
|
"""Get model parameter information."""
|
||||||
|
|
@ -391,17 +324,8 @@ async def bridge_state_dict_info() -> JSONResponse:
|
||||||
return JSONResponse({"error": str(e)})
|
return JSONResponse({"error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bridge/disable")
|
|
||||||
async def bridge_disable() -> JSONResponse:
|
|
||||||
"""Disable the weight bridge."""
|
|
||||||
with bridge_state.lock:
|
|
||||||
bridge_state.enabled = False
|
|
||||||
logger.info("Bridge disabled")
|
|
||||||
return JSONResponse({"status": "ok"})
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Pause/Resume Endpoints (for weight updates)
|
# Pause/Resume Endpoints
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue