atropos/example_trainer/vllm_api_server.py
2026-03-02 11:18:51 -05:00

679 lines
24 KiB
Python

"""
Custom vLLM API server with shared memory weight updates.
This server extends the standard vLLM API with:
- Shared-weight training via NCCL (patched GPUModelRunner)
- LoRA hot-swap without server restart
- Weight synchronization endpoints
ARCHITECTURE:
When --enable-shared-weights is set:
1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
2. A daemon process is spawned that receives NCCL weight updates
3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
┌─────────────────────────────────────────────────────────────────────────┐
│ SHARED MEMORY (via share_memory_()) │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Model Weights │ │
│ │ (accessible from MULTIPLE processes) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ▲ ▲ │
│ │ Reads │ Writes │
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
│ │ vLLM Worker │ │ weight_updater │ │
│ │ (inference) │ │ daemon process │ │
│ └─────────────────┘ └───────────┬───────────┘ │
│ │ NCCL │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Trainer Process │ │
│ └─────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
"""
import asyncio
import json
import os
import ssl
import threading
import time
from argparse import Namespace
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import vllm.envs as envs
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.v1.engine.async_llm import AsyncLLM
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
except ImportError:
from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.outputs import RequestOutput # noqa: F401
from vllm.version import __version__ as VLLM_VERSION
# Module logger, created through vLLM's init_logger under the name
# "vllm.entrypoints.api_server", so its records share that logger namespace.
logger = init_logger("vllm.entrypoints.api_server")
# =============================================================================
# Apply vLLM Patches for Shared Memory
# =============================================================================
def _maybe_apply_patches() -> bool:
    """
    Apply vLLM patches if shared weights are enabled.

    Patching is requested when ``VLLM_ENABLE_SHARED_WEIGHTS`` is ``"1"`` or
    when ``NUM_INFERENCE_NODES`` is a non-negative integer. A malformed
    ``NUM_INFERENCE_NODES`` is logged and treated as unset instead of
    crashing the module at import time.

    The patch module is imported relative to this package. When this file is
    executed directly as a script there is no parent package, so the relative
    import cannot succeed; in that case a top-level ``vllm_patching`` import
    is attempted instead.

    Returns:
        True if patches were applied. False if patching was not requested or
        failed; failures are logged as warnings, never raised, so the server
        can still start without shared-memory weight updates.
    """
    enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
    raw_nodes = os.environ.get("NUM_INFERENCE_NODES", "-1")
    try:
        num_inference_nodes = int(raw_nodes)
    except ValueError:
        logger.warning(
            f"Ignoring invalid NUM_INFERENCE_NODES={raw_nodes!r}; treating it as unset"
        )
        num_inference_nodes = -1
    if not enable_shared and num_inference_nodes < 0:
        return False
    try:
        try:
            from .vllm_patching import apply_patches
        except ImportError:
            # A relative import needs a parent package. Only fall back to the
            # top-level module when running as a standalone script; inside a
            # package the original ImportError is the real problem.
            if __package__:
                raise
            from vllm_patching import apply_patches
        apply_patches()
        logger.info("✓ vLLM patches applied for shared memory weights")
        return True
    except ImportError as e:
        logger.warning(f"Could not import vllm_patching: {e}")
        logger.warning("Shared memory weight updates will not be available")
        return False
    except Exception as e:
        logger.warning(f"Failed to apply patches: {e}")
        return False
# Apply the shared-memory patches once, at import time, and remember the
# outcome; handlers below read PATCHES_APPLIED to report the mode.
# NOTE(review): several vllm modules (vllm.envs, AsyncEngineArgs, AsyncLLM,
# ...) are already imported above, so this does NOT run "before any other
# vLLM imports" — presumably the patched classes are only used later;
# confirm against vllm_patching.apply_patches.
PATCHES_APPLIED = _maybe_apply_patches()
# =============================================================================
# Global State
# =============================================================================
# The single FastAPI application that every route in this module registers on.
app = FastAPI()
# Set by init_app(); None until then (most handlers answer 503 while unset).
engine: Optional[AsyncLLM] = None
@dataclass
class BridgeState:
"""State for weight bridge synchronization."""
enabled: bool = False
update_count: int = 0
last_update_time: float = 0.0
rendezvous_info: Dict[str, Any] = field(default_factory=dict)
lock: threading.Lock = field(default_factory=threading.Lock)
# LoRA state
active_lora_path: Optional[str] = None
lora_load_count: int = 0
bridge_state = BridgeState()
# =============================================================================
# Pydantic Models for API
# =============================================================================
class BridgeInfoResponse(BaseModel):
    """Schema for the bridge status object reported by ``GET /bridge/info``.

    The handler for that route builds a plain JSON dict instead of
    instantiating this model, so the two must be kept in sync by hand.
    """

    enabled: bool
    update_count: int
    last_update_time: float
    rendezvous_info: Dict[str, Any]
    model_name: str
    device: str
    # Whether the vLLM shared-memory patches were applied. The /bridge/info
    # payload carries this key; defaulted so callers that build the model
    # without it keep validating.
    shared_weights: bool = False
class BridgeInitRequest(BaseModel):
    """Body of ``POST /bridge/init``: rendezvous settings for the weight bridge."""

    master_addr: str
    master_port: int
    world_size: int
    trainer_ranks: list[int]
class WeightUpdateNotification(BaseModel):
    """Body of ``POST /bridge/notify_update`` sent by the trainer."""

    update_count: int
    trainer_rank: int
    timestamp: float  # stored verbatim as bridge_state.last_update_time
class LoraLoadRequest(BaseModel):
    """Body of ``POST /lora/load``."""

    adapter_path: str
    adapter_name: str | None = None
class LoraStatusResponse(BaseModel):
    """Response of ``GET /lora/status``."""

    active_adapter: str | None
    load_count: int
    available_adapters: list[str]
# =============================================================================
# Health Endpoints
# =============================================================================
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers 200 without touching the engine."""
    ok = Response(status_code=200)
    return ok
@app.get("/health_generate")
async def health_generate() -> Response:
    """Health check that verifies the model can generate.

    Runs a single-token generation from the prompt token ids ``[0]``.

    Returns:
        200 when the generation completes, 499 if the probe is cancelled.

    Raises:
        HTTPException: 503 if the engine is not initialized; 500 (detail is
            the error text) if generation fails.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    # One token exercises the whole request path; SamplingParams() defaults
    # would generate a longer completion for no benefit.
    sampling_params = SamplingParams(max_tokens=1)
    request_id = random_uuid()
    try:
        async for _ in engine.generate(
            {"prompt_token_ids": [0]}, sampling_params, request_id
        ):
            pass
    except asyncio.CancelledError:
        return Response(status_code=499)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return Response(status_code=200)
# =============================================================================
# Generation Endpoints
# =============================================================================
@app.post("/generate")
async def generate(request: Request) -> Response:
    """
    Generate completion for the request.

    The request body must be a JSON object with:
    - prompt: the prompt to use for generation (required)
    - stream: whether to stream results (optional, default False)
    - other fields: forwarded to ``SamplingParams`` as keyword arguments

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON
            object (previously these surfaced as 500 errors).
    """
    try:
        request_dict = await request.json()
    except ValueError as e:
        # Covers json.JSONDecodeError and undecodable (non-UTF) bodies.
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(request_dict, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return await _generate(request_dict, raw_request=request)
def _prompt_text(request_output: Any) -> str:
    """Return the prompt text of an engine output.

    When the output carries no prompt text (e.g. the request was submitted
    as token ids), decode ``prompt_token_ids`` with the engine tokenizer.
    """
    return request_output.prompt or engine.tokenizer.decode(
        request_output.prompt_token_ids
    )


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Internal generate handler.

    Args:
        request_dict: Parsed request body. ``prompt`` (required) and
            ``stream`` are popped; all remaining keys are passed to
            ``SamplingParams``, with ``output_kind`` forced to FINAL_ONLY.
        raw_request: The originating request, consumed by the
            ``with_cancellation`` decorator (presumably to watch for client
            disconnects).

    Returns:
        When streaming, a StreamingResponse of newline-delimited JSON objects
        ``{"text": [prompt + completion, ...]}``. Otherwise a JSONResponse
        with ``text``, ``prompt`` and ``finish_reasons`` plus, when logprobs
        were requested, ``logprobs``, ``prompt_token_ids`` and ``token_ids``.
        A 499 response if cancelled while collecting the final output.

    Raises:
        HTTPException: 503 if the engine is not initialized; 400 if
            ``prompt`` is missing or the sampling parameters are rejected;
            500 if the engine yields no output.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    if "prompt" not in request_dict:
        raise HTTPException(status_code=400, detail="Missing required field: prompt")
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as e:
        # Unknown keys (TypeError) or rejected values (ValueError) are client
        # errors, not server failures.
        raise HTTPException(
            status_code=400, detail=f"Invalid sampling parameters: {e}"
        ) from e
    request_id = random_uuid()
    results_generator = engine.generate(prompt, sampling_params, request_id)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            # Same prompt resolution as the non-streaming path, instead of
            # asserting that the prompt text is present.
            prompt_text = _prompt_text(request_output)
            text_outputs = [prompt_text + output.text for output in request_output.outputs]
            yield (json.dumps({"text": text_outputs}) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)
    if final_output is None:
        # Explicit error rather than ``assert`` (which ``python -O`` strips).
        raise HTTPException(status_code=500, detail="Engine returned no output")

    ret = {
        "text": [output.text for output in final_output.outputs],
        "prompt": _prompt_text(final_output),
        "finish_reasons": [output.finish_reason for output in final_output.outputs],
    }
    if sampling_params.logprobs is not None:
        # Per completion: one single-element list per generated position,
        # each mapping the returned keys (presumably token ids) to their
        # float logprob.
        ret["logprobs"] = [
            [[{key: value.logprob for key, value in logprob.items()}]
             for logprob in x.logprobs]
            for x in final_output.outputs
        ]
        ret["prompt_token_ids"] = final_output.prompt_token_ids
        ret["token_ids"] = [x.token_ids for x in final_output.outputs]
    return JSONResponse(ret)
# =============================================================================
# Bridge Endpoints (Weight Synchronization)
# =============================================================================
@app.get("/bridge/info")
async def bridge_info() -> JSONResponse:
    """Report bridge status, rendezvous details, served model and device.

    ``enabled`` is true when the bridge was initialized or the shared-memory
    patches are active.

    Raises:
        HTTPException: 503 if the engine is not initialized.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    if hasattr(engine, "model_config"):
        model_name = str(engine.model_config.model)
    else:
        model_name = "unknown"
    payload = {
        "enabled": bridge_state.enabled or PATCHES_APPLIED,
        "shared_weights": PATCHES_APPLIED,
        "update_count": bridge_state.update_count,
        "last_update_time": bridge_state.last_update_time,
        "rendezvous_info": bridge_state.rendezvous_info,
        "model_name": model_name,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
    return JSONResponse(payload)
@app.post("/bridge/init")
async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
    """Store the trainer's rendezvous parameters and mark the bridge enabled."""
    rendezvous = {
        "master_addr": request.master_addr,
        "master_port": request.master_port,
        "world_size": request.world_size,
        "trainer_ranks": request.trainer_ranks,
    }
    with bridge_state.lock:
        bridge_state.enabled = True
        bridge_state.rendezvous_info = rendezvous
        logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
    body = {
        "status": "ok",
        "message": "Weight bridge initialized",
        "shared_weights_enabled": PATCHES_APPLIED,
    }
    return JSONResponse(body)
@app.post("/bridge/notify_update")
async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
    """
    Record a trainer notification that weights were updated.

    The handler stores the reported update count and timestamp and logs the
    update (debug level in shared-memory mode, info level otherwise). It
    does not move any weights itself.
    """
    count = notification.update_count
    with bridge_state.lock:
        bridge_state.update_count = count
        bridge_state.last_update_time = notification.timestamp
        if not PATCHES_APPLIED:
            logger.info(f"Weight update #{count} (HTTP notification only)")
        else:
            logger.debug(f"Weight update #{count} (shared memory)")
    return JSONResponse(
        {
            "status": "ok",
            "update_count": bridge_state.update_count,
            "shared_weights": PATCHES_APPLIED,
        }
    )
@app.get("/bridge/state_dict_info")
async def bridge_state_dict_info() -> JSONResponse:
    """Report the served model name, its dtype and the shared-weights flag.

    A failure while reading the model config is returned as
    ``{"error": ...}`` with the default 200 status, not raised.

    Raises:
        HTTPException: 503 if the engine is not initialized.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        cfg = engine.model_config
        summary = {
            "model": str(cfg.model),
            "dtype": str(cfg.dtype),
            "shared_weights_enabled": PATCHES_APPLIED,
        }
        return JSONResponse(summary)
    except Exception as exc:
        return JSONResponse({"error": str(exc)})
@app.post("/bridge/disable")
async def bridge_disable() -> JSONResponse:
    """Clear the bridge's ``enabled`` flag; rendezvous info and counters are kept."""
    with bridge_state.lock:
        bridge_state.enabled = False
        logger.info("Bridge disabled")
    return JSONResponse({"status": "ok"})
# =============================================================================
# Pause/Resume Endpoints (for weight updates)
# =============================================================================
@app.post("/bridge/pause")
async def bridge_pause() -> JSONResponse:
    """Pause generation to allow weight updates.

    Relies on the private AsyncLLM attributes ``_pause_cond`` and ``_paused``.
    When they are missing the response is ``{"status": "not_supported"}``.
    This handler only sets the flag; it does not itself abort or wait for
    in-flight requests.

    Raises:
        HTTPException: 503 if the engine is not initialized; 500 on any
            other failure.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        if not hasattr(engine, "_pause_cond"):
            return JSONResponse({"status": "not_supported"})
        async with engine._pause_cond:
            engine._paused = True
        logger.info("Engine paused")
        return JSONResponse({"status": "paused"})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/bridge/resume")
async def bridge_resume() -> JSONResponse:
    """Resume generation after weight updates.

    Clears the private ``_paused`` flag under ``_pause_cond`` and wakes every
    waiter with ``notify_all``. Responds ``not_supported`` when the engine has
    no ``_pause_cond``.

    Raises:
        HTTPException: 503 if the engine is not initialized; 500 on any
            other failure.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        if not hasattr(engine, "_pause_cond"):
            return JSONResponse({"status": "not_supported"})
        cond = engine._pause_cond
        async with cond:
            engine._paused = False
            cond.notify_all()
        logger.info("Engine resumed")
        return JSONResponse({"status": "resumed"})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/bridge/is_paused")
async def bridge_is_paused() -> JSONResponse:
    """Report the engine's private ``_paused`` flag (False when absent)."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return JSONResponse({"paused": getattr(engine, "_paused", False)})
# =============================================================================
# Sleep/Wake Endpoints (GPU memory management)
# =============================================================================
@app.post("/bridge/sleep")
async def bridge_sleep() -> JSONResponse:
    """Put the engine to sleep via ``engine.sleep()`` (intended to free GPU memory).

    Raises:
        HTTPException: 503 if the engine is not initialized; 500 if the
            sleep call fails.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        await engine.sleep()
        logger.info("Engine sleeping")
        reply = {"status": "sleeping"}
        return JSONResponse(reply)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.post("/bridge/wake_up")
async def bridge_wake_up() -> JSONResponse:
    """Wake the engine via ``engine.wake_up()``.

    Raises:
        HTTPException: 503 if the engine is not initialized; 500 if the
            wake-up call fails.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    try:
        await engine.wake_up()
        logger.info("Engine woken up")
        reply = {"status": "awake"}
        return JSONResponse(reply)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.get("/bridge/is_sleeping")
async def bridge_is_sleeping() -> JSONResponse:
    """Report whether ``engine.is_sleeping()`` returns true."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return JSONResponse({"sleeping": await engine.is_sleeping()})
# =============================================================================
# Debug Endpoints
# =============================================================================
@app.get("/bridge/debug")
async def bridge_debug() -> JSONResponse:
    """Debug endpoint to inspect engine state.

    Always returns the engine type, vLLM version, patch status and the raw
    shared-weight environment settings. When an engine exists, adds its model
    name and dtype under ``model_config``. If reading them fails, the error is
    reported under ``model_config_error`` instead of being silently dropped.
    """
    debug_info = {
        "engine_type": type(engine).__name__ if engine is not None else None,
        "vllm_version": VLLM_VERSION,
        "patches_applied": PATCHES_APPLIED,
        "shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"),
        "num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
    }
    if engine is not None:
        try:
            cfg = engine.model_config
            debug_info["model_config"] = {
                "model": str(cfg.model),
                "dtype": str(cfg.dtype),
            }
        except Exception as e:
            # A debug endpoint should explain missing data, not hide it.
            debug_info["model_config_error"] = f"{type(e).__name__}: {e}"
    return JSONResponse(debug_info)
@app.get("/bridge/list_endpoints")
async def list_endpoints() -> JSONResponse:
    """List every registered route that exposes both ``path`` and ``methods``."""
    endpoints = [
        {"path": route.path, "methods": list(route.methods)}
        for route in app.routes
        if hasattr(route, "path") and hasattr(route, "methods")
    ]
    return JSONResponse({"endpoints": endpoints})
# =============================================================================
# LoRA Endpoints
# =============================================================================
@app.get("/lora/status")
async def lora_status() -> LoraStatusResponse:
    """Report the recorded LoRA adapter state.

    ``available_adapters`` names the subdirectories of ``$LOGDIR`` (default:
    the current directory) that contain an ``adapter_config.json``.
    """
    root = os.environ.get("LOGDIR", ".")
    adapters = []
    if os.path.exists(root):
        adapters = [
            name
            for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
            and os.path.exists(os.path.join(root, name, "adapter_config.json"))
        ]
    return LoraStatusResponse(
        active_adapter=bridge_state.active_lora_path,
        load_count=bridge_state.lora_load_count,
        available_adapters=adapters,
    )
@app.post("/lora/load")
async def lora_load(request: LoraLoadRequest) -> JSONResponse:
    """Record ``adapter_path`` as the active LoRA adapter.

    Bookkeeping only: after checking that the path exists on this host, it
    sets ``bridge_state.active_lora_path`` and increments
    ``lora_load_count``. This handler does not hand the adapter to the
    engine, and ``adapter_name`` is currently unused.

    Raises:
        HTTPException: 404 if ``adapter_path`` does not exist.
    """
    path = request.adapter_path
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail=f"Adapter not found: {path}")
    with bridge_state.lock:
        bridge_state.active_lora_path = path
        bridge_state.lora_load_count += 1
        logger.info(f"LoRA adapter loaded: {path}")
    return JSONResponse(
        {
            "status": "ok",
            "adapter_path": path,
            "load_count": bridge_state.lora_load_count,
        }
    )
@app.post("/lora/unload")
async def lora_unload() -> JSONResponse:
    """Forget the recorded active LoRA adapter and return the previous path."""
    with bridge_state.lock:
        previous, bridge_state.active_lora_path = bridge_state.active_lora_path, None
        logger.info(f"LoRA adapter unloaded: {previous}")
    return JSONResponse({"status": "ok", "previous_adapter": previous})
# =============================================================================
# Server Setup
# =============================================================================
def build_app(args: Namespace) -> FastAPI:
    """Set ``root_path`` on the module-level FastAPI app and return that app.

    Only an attribute of the existing app is mutated, so no ``global``
    rebinding is required.
    """
    app.root_path = args.root_path
    return app
async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
    """Initialize the application and vLLM engine.

    Args:
        args: Parsed CLI namespace (server flags plus vLLM engine flags).
        llm_engine: Pre-built engine to serve. When omitted, an AsyncLLM is
            created from ``args``.

    Returns:
        The module-level FastAPI app returned by ``build_app``.

    Side effects:
        Sets the module-global ``engine``, stores it on
        ``app.state.engine_client`` and writes the bridge config JSON via
        ``_export_state_dict_info``.
    """
    global engine
    app = build_app(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        # Engine args are only needed when the engine is built here; parsing
        # them for a caller-supplied engine was wasted work that could fail.
        engine_args = AsyncEngineArgs.from_cli_args(args)
        engine = AsyncLLM.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    app.state.engine_client = engine
    # Export state dict info for trainers
    _export_state_dict_info(args)
    return app
def _export_state_dict_info(args: Namespace) -> None:
    """Export model parameter mapping to JSON for trainer.

    Writes ``$LOGDIR/vllm_bridge_config.json`` (``LOGDIR`` defaults to the
    current directory) with:
    - the model name, dtype and tensor-parallel degree;
    - a fixed ``dp_shard_degree`` of 1;
    - an empty ``param_mappings`` placeholder;
    - whether the shared-weight patches are active.

    The log directory is created if missing. The JSON is written to a
    temporary sibling and moved into place with ``os.replace``, which is
    atomic on POSIX, so a reader never observes a half-written file.
    Failures are logged as warnings, never raised.

    Args:
        args: Parsed CLI namespace. ``model``, ``dtype`` and
            ``tensor_parallel_size`` are read when present.
    """
    log_dir = Path(os.environ.get("LOGDIR", "."))
    json_path = log_dir / "vllm_bridge_config.json"
    # PID-suffixed temp name so several servers sharing one LOGDIR do not
    # clobber each other's in-progress writes.
    tmp_path = json_path.with_name(f".{json_path.name}.{os.getpid()}.tmp")
    try:
        info = {
            "model": getattr(args, "model", "unknown"),
            "dtype": getattr(args, "dtype", "auto"),
            "tp_degree": getattr(args, "tensor_parallel_size", 1),
            "dp_shard_degree": 1,
            "param_mappings": {},
            "shared_weights_enabled": PATCHES_APPLIED,
        }
        # Serialize before touching the filesystem so an unserializable value
        # cannot leave a truncated JSON file behind.
        payload = json.dumps(info, indent=2)
        log_dir.mkdir(parents=True, exist_ok=True)
        tmp_path.write_text(payload, encoding="utf-8")
        os.replace(tmp_path, json_path)
        logger.info(f"Exported state dict info to {json_path}")
    except Exception as e:
        logger.warning(f"Failed to export state dict info: {e}")
        try:
            tmp_path.unlink()
        except OSError:
            pass  # best-effort cleanup; the temp file may never have been created
async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any) -> None:
    """Run the vLLM API server.

    Logs start-up information, calls vLLM's ``set_ulimit``, builds the app
    and engine with ``init_app`` and serves over HTTP(S) via ``serve_http``
    until the returned shutdown task completes.

    Args:
        args: Parsed CLI namespace (server flags plus vLLM engine flags).
        llm_engine: Optional pre-built engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.

    Raises:
        RuntimeError: if no engine is available after initialization.
    """
    divider = "=" * 60
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)
    if PATCHES_APPLIED:
        for line in (
            divider,
            "SHARED MEMORY MODE ENABLED",
            "Weight updates from trainer will be reflected immediately!",
            divider,
        ):
            logger.info(line)
    set_ulimit()
    app = await init_app(args, llm_engine)
    if engine is None:
        raise RuntimeError("No engine initialized")
    # Human-readable summary of the main routes.
    endpoint_summary = (
        divider,
        "Available endpoints:",
        " POST /generate - Generate completions",
        " GET /bridge/info - Bridge status",
        " POST /bridge/pause - Pause generation",
        " POST /bridge/resume - Resume generation",
        " GET /lora/status - LoRA adapter status",
        divider,
    )
    for line in endpoint_summary:
        logger.info(line)
    ssl_options = {
        "ssl_keyfile": getattr(args, "ssl_keyfile", None),
        "ssl_certfile": getattr(args, "ssl_certfile", None),
        "ssl_ca_certs": getattr(args, "ssl_ca_certs", None),
        "ssl_cert_reqs": getattr(args, "ssl_cert_reqs", ssl.CERT_NONE),
    }
    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False),
        host=args.host,
        port=args.port,
        log_level=getattr(args, "log_level", "info"),
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        **ssl_options,
        **uvicorn_kwargs,
    )
    await shutdown_task
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Server-level flags, registered in this order before vLLM's engine flags.
    for flag, options in (
        ("--host", {"type": str, "default": None}),
        ("--port", {"type": int, "default": 9001}),
        ("--ssl-keyfile", {"type": str, "default": None}),
        ("--ssl-certfile", {"type": str, "default": None}),
        ("--ssl-ca-certs", {"type": str, "default": None}),
        ("--enable-ssl-refresh", {"action": "store_true", "default": False}),
        ("--ssl-cert-reqs", {"type": int, "default": int(ssl.CERT_NONE)}),
        ("--root-path", {"type": str, "default": None}),
        ("--log-level", {"type": str, "default": "info"}),
    ):
        parser.add_argument(flag, **options)
    # Append vLLM's engine flags.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))