atropos/example_trainer/vllm_api_server.py
Jai Suphavadeeprasit 0dadc774ac nccl loras 2
2026-02-13 11:26:25 -05:00

1028 lines
36 KiB
Python

#!/usr/bin/env python3
"""
Custom vLLM API server with CUDA IPC shared memory support.
This server extends the standard vLLM API with:
- Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors
- LoRA hot-swap without server restart
- Bridge endpoints for coordination
ARCHITECTURE (Single-Copy Mode):
When VLLM_ENABLE_SHARED_WEIGHTS=1:
1. vLLM's GPUModelRunner is patched BEFORE loading
2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json
3. Trainer reads IPC handles and attaches to the SAME tensors
4. optimizer.step() updates weights in-place - vLLM sees changes immediately!
┌─────────────────────────────────────────────────────────────────────────┐
│ SINGLE GPU (True Shared Memory) │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Model Weights (ONE copy!) │ │
│ │ (accessible via CUDA IPC handles) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ▲ ▲ │
│ │ Reads (inference) │ Writes (train) │
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
│ │ vLLM Worker │ │ Trainer Process │ │
│ │ │ │ (attached via IPC) │ │
│ └─────────────────┘ └───────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
CRITICAL: Patches must be applied BEFORE importing vLLM!
"""
# =============================================================================
# STEP 0: Standard library imports ONLY (no vLLM yet!)
# =============================================================================
import asyncio
import json
import multiprocessing
import os
import ssl
import threading
from argparse import Namespace
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional
# Process-wide defaults that have to be in place before vLLM is imported.
# setdefault() keeps whatever the user already exported, so e.g. VLLM_USE_V1=1
# still opts into the v1 engine.
for _env_name, _env_default in (
    ("VLLM_USE_V1", "0"),  # default to the v0 engine to avoid CUDA fork issues with v1
    ("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),  # spawn is required for CUDA
):
    os.environ.setdefault(_env_name, _env_default)

try:
    multiprocessing.set_start_method("spawn", force=True)
except RuntimeError:
    # The start method was already fixed; nothing left to do.
    pass
# =============================================================================
# STEP 1: Apply patches BEFORE any vLLM imports!
# =============================================================================
def _apply_patches_early() -> bool:
"""
Apply vLLM patches if shared weights are enabled.
This MUST be called before any vLLM imports!
Returns True if patches were applied.
"""
enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
if not enable_shared and num_inference_nodes < 0:
print("[vLLM Server] Shared weights not enabled, skipping patches")
return False
print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...")
try:
# Try relative import first (when run as module)
from .vllm_patching import apply_patches
except ImportError:
# Fall back to absolute import (when run as script)
try:
import sys
from pathlib import Path
# Add parent directory to path so we can import vllm_patching
script_dir = Path(__file__).parent
if str(script_dir) not in sys.path:
sys.path.insert(0, str(script_dir))
from vllm_patching import apply_patches
except ImportError as e:
print(f"[vLLM Server] Could not import vllm_patching: {e}")
print("[vLLM Server] Shared memory weight updates will not be available")
return False
try:
success = apply_patches()
if success:
print("[vLLM Server] ✓ vLLM patches applied successfully!")
else:
print("[vLLM Server] ✗ Failed to apply patches")
return success
except Exception as e:
print(f"[vLLM Server] Error applying patches: {e}")
import traceback
traceback.print_exc()
return False
# Apply patches NOW, before any vLLM imports below!
PATCHES_APPLIED = _apply_patches_early()
# =============================================================================
# STEP 2: Now safe to import vLLM (patches are already in place)
# =============================================================================
import torch # noqa: E402
import vllm.envs as envs # noqa: E402
from fastapi import FastAPI, HTTPException, Request # noqa: E402
from fastapi.responses import JSONResponse, Response, StreamingResponse # noqa: E402
from pydantic import BaseModel # noqa: E402
from vllm.engine.arg_utils import AsyncEngineArgs # noqa: E402
from vllm.entrypoints.launcher import serve_http # noqa: E402
from vllm.entrypoints.utils import with_cancellation # noqa: E402
from vllm.logger import init_logger # noqa: E402
from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402
from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
# FlexibleArgumentParser is not importable from vllm.utils in every vLLM
# release; prefer vLLM's class and fall back to a local stand-in otherwise.
try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    import argparse

    class FlexibleArgumentParser(argparse.ArgumentParser):
        """ArgumentParser stand-in that ignores the 'deprecated' keyword."""

        def add_argument(self, *args, **kwargs):
            # 'deprecated' is not accepted by argparse before Python 3.13,
            # so it is filtered out before delegating.
            accepted = {
                key: value for key, value in kwargs.items() if key != "deprecated"
            }
            return super().add_argument(*args, **accepted)
# Some vLLM releases do not ship set_ulimit; substitute a do-nothing stub.
try:
    from vllm.utils import set_ulimit
except ImportError:

    def set_ulimit() -> None:
        """Stand-in used when vLLM provides no set_ulimit; does nothing."""
        return None
from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
# LoRA adapter support is optional: record whether LoRARequest can be imported
# and leave a None placeholder behind when it cannot.
try:
    from vllm.lora.request import LoRARequest  # noqa: E402
except ImportError:
    LoRARequest = None  # type: ignore
    LORA_AVAILABLE = False
else:
    LORA_AVAILABLE = True
# Module logger, obtained from vLLM's init_logger under the
# "vllm.entrypoints.api_server" name.
logger = init_logger("vllm.entrypoints.api_server")
# The single FastAPI application; every route in this file is registered on it.
app = FastAPI()
# Engine shared by the request handlers. It stays None until init_app()
# assigns it; most handlers answer 503 while it is still None.
engine: Optional[AsyncLLM] = None
@dataclass
class BridgeState:
"""State for shared memory and LoRA."""
update_count: int = 0
last_update_time: float = 0.0
lock: threading.Lock = field(default_factory=threading.Lock)
# LoRA state
active_lora_path: Optional[str] = None
active_lora_name: Optional[str] = None
active_lora_id: int = 0 # vLLM requires unique integer ID per adapter
lora_load_count: int = 0
bridge_state = BridgeState()
def _get_lora_request() -> Optional["LoRARequest"]:
    """
    Build a LoRARequest describing the currently active adapter.

    Returns None when LoRA support could not be imported or when no adapter
    path is recorded in ``bridge_state``.
    """
    adapter_path = bridge_state.active_lora_path
    if not LORA_AVAILABLE or adapter_path is None:
        return None

    # An adapter loaded without a name is reported under a fixed fallback name.
    adapter_name = bridge_state.active_lora_name or "default_adapter"
    return LoRARequest(
        lora_name=adapter_name,
        lora_int_id=bridge_state.active_lora_id,
        lora_path=adapter_path,
    )
# =============================================================================
# Pydantic Models for API
# =============================================================================
# Schema describing bridge status. No route in the visible part of this file
# uses it as a response model; /bridge/info assembles its JSON by hand.
class BridgeInfoResponse(BaseModel):
    enabled: bool
    update_count: int
    last_update_time: float
    model_name: str
    device: str
# Request body of POST /lora/load.
class LoraLoadRequest(BaseModel):
    # Filesystem path of the adapter; the handler checks that it exists.
    adapter_path: str
    # Optional display name; the handler generates "adapter_<n>" when omitted.
    adapter_name: Optional[str] = None
# Response body of GET /lora/status.
class LoraStatusResponse(BaseModel):
    # Whether vLLM's LoRARequest could be imported at startup.
    lora_available: bool
    # Active adapter details; the id is None when no adapter path is active.
    active_adapter_path: Optional[str]
    active_adapter_name: Optional[str]
    active_adapter_id: Optional[int]
    # Number of adapter loads performed so far.
    load_count: int
    # Names of LOGDIR sub-directories that contain an adapter_config.json.
    available_adapters: List[str]
# =============================================================================
# Health Endpoints
# =============================================================================
@app.get("/health")
async def health() -> Response:
    """Liveness probe: answers 200 without consulting the engine."""
    ok = Response(status_code=200)
    return ok
@app.get("/health_generate")
async def health_generate() -> Response:
    """
    Readiness probe that runs a tiny generation through the engine.

    Answers 200 once the generation finishes, 499 when the request is
    cancelled, 503 while the engine is missing and 500 on any other error.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    probe_params = SamplingParams()
    probe_id = random_uuid()
    try:
        outputs = engine.generate({"prompt_token_ids": [0]}, probe_params, probe_id)
        # Drain the generator; only successful completion matters.
        async for _ in outputs:
            pass
    except asyncio.CancelledError:
        return Response(status_code=499)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return Response(status_code=200)
# =============================================================================
# Generation Endpoints
# =============================================================================
@app.post("/generate")
async def generate(request: Request) -> Response:
    """
    HTTP entry point for text generation.

    The JSON body carries:
    - prompt: the prompt to use for generation
    - stream: whether to stream results
    - any other field: passed on as a sampling parameter
    """
    payload = await request.json()
    response = await _generate(payload, raw_request=request)
    return response
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """
    Run one generation request against the engine.

    Args:
        request_dict: Parsed request body. "prompt" is required, "stream" is
            optional (default False); every remaining key is forwarded to
            ``SamplingParams``. The dict is mutated in place.
        raw_request: The originating HTTP request. It is not read in this
            body (NOTE(review): presumably consumed by ``with_cancellation``
            to detect client disconnects - confirm).

    Returns:
        A StreamingResponse of newline-delimited JSON when streaming was
        requested, a JSONResponse with text, prompt, finish reasons, token
        ids (and logprobs when requested) otherwise, or a bare 499 response
        when the request was cancelled.

    Raises:
        HTTPException: 503 when the engine is not initialized, 400 when the
            body has no prompt or the sampling parameters are rejected, 500
            when the engine produced no output.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    # Malformed bodies are the client's fault: answer 400 rather than letting
    # a KeyError/AttributeError surface as a 500.
    if not isinstance(request_dict, dict) or "prompt" not in request_dict:
        raise HTTPException(
            status_code=400,
            detail="Request body must be a JSON object with a 'prompt' field",
        )

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as e:
        # Unknown keys or out-of-range values in the request body.
        raise HTTPException(
            status_code=400, detail=f"Invalid sampling parameters: {e}"
        ) from e
    request_id = random_uuid()

    # Get active LoRA adapter if any
    lora_request = _get_lora_request()
    results_generator = engine.generate(
        prompt, sampling_params, request_id, lora_request=lora_request
    )

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt_text = request_output.prompt
            if prompt_text is None:
                # Token-id prompts carry no text; decode them the same way
                # the non-streaming path does instead of asserting.
                prompt_text = engine.tokenizer.decode(request_output.prompt_token_ids)
            text_outputs = [
                prompt_text + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    if final_output is None:
        # Explicit error instead of an assert, which is stripped under -O.
        raise HTTPException(status_code=500, detail="Engine produced no output")

    prompt = final_output.prompt or engine.tokenizer.decode(
        final_output.prompt_token_ids
    )
    text_outputs = [output.text for output in final_output.outputs]
    finish_reasons = [output.finish_reason for output in final_output.outputs]
    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}

    if sampling_params.logprobs is not None:
        # One single-element list per generated token, holding a
        # {token_id: logprob} mapping for that position.
        output_logprobs = [
            [
                [{key: value.logprob for key, value in logprob.items()}]
                for logprob in x.logprobs
            ]
            for x in final_output.outputs
        ]
        ret["logprobs"] = output_logprobs
        ret["prompt_token_ids"] = final_output.prompt_token_ids
        ret["token_ids"] = [x.token_ids for x in final_output.outputs]

    return JSONResponse(ret)
# =============================================================================
# Bridge Endpoints (Weight Synchronization)
# =============================================================================
@app.get("/bridge/info")
async def bridge_info() -> JSONResponse:
    """Report whether patches are applied, the update counters, model name and device."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    if hasattr(engine, "model_config"):
        model_name = str(engine.model_config.model)
    else:
        model_name = "unknown"

    payload = {
        "enabled": PATCHES_APPLIED,
        "shared_weights": PATCHES_APPLIED,
        "update_count": bridge_state.update_count,
        "last_update_time": bridge_state.last_update_time,
        "model_name": model_name,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
    return JSONResponse(payload)
@app.get("/bridge/state_dict_info")
async def bridge_state_dict_info() -> JSONResponse:
    """
    Report the model name, dtype and whether shared weights are enabled.

    Any failure while reading the model config is returned as an
    {"error": ...} JSON body rather than raised.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    try:
        cfg = engine.model_config
        summary = {
            "model": str(cfg.model),
            "dtype": str(cfg.dtype),
            "shared_weights_enabled": PATCHES_APPLIED,
        }
        return JSONResponse(summary)
    except Exception as e:
        return JSONResponse({"error": str(e)})
# =============================================================================
# Pause/Resume Endpoints
# =============================================================================
@app.post("/bridge/pause")
async def bridge_pause() -> JSONResponse:
    """
    Pause generation so weights can be updated.

    Answers {"status": "not_supported"} when the engine has no
    ``_pause_cond`` attribute.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    try:
        # vLLM v1 supports pause/resume; other engines lack the condition.
        if not hasattr(engine, "_pause_cond"):
            return JSONResponse({"status": "not_supported"})

        async with engine._pause_cond:
            engine._paused = True
        logger.info("Engine paused")
        return JSONResponse({"status": "paused"})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/bridge/resume")
async def bridge_resume() -> JSONResponse:
    """
    Resume generation once weight updates are done.

    Answers {"status": "not_supported"} when the engine has no
    ``_pause_cond`` attribute.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    try:
        if not hasattr(engine, "_pause_cond"):
            return JSONResponse({"status": "not_supported"})

        async with engine._pause_cond:
            engine._paused = False
            # Wake everything waiting on the condition.
            engine._pause_cond.notify_all()
        logger.info("Engine resumed")
        return JSONResponse({"status": "resumed"})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/bridge/is_paused")
async def bridge_is_paused() -> JSONResponse:
    """Report the engine's ``_paused`` flag, defaulting to False when it is absent."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    return JSONResponse({"paused": getattr(engine, "_paused", False)})
# =============================================================================
# Sleep/Wake Endpoints (GPU memory management)
# =============================================================================
@app.post("/bridge/sleep")
async def bridge_sleep() -> JSONResponse:
    """Ask the engine to sleep (intended to free GPU memory); 500 on failure."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    try:
        await engine.sleep()
        logger.info("Engine sleeping")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return JSONResponse({"status": "sleeping"})
@app.post("/bridge/wake_up")
async def bridge_wake_up() -> JSONResponse:
    """Ask the engine to wake up again; 500 on failure."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    try:
        await engine.wake_up()
        logger.info("Engine woken up")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return JSONResponse({"status": "awake"})
@app.get("/bridge/is_sleeping")
async def bridge_is_sleeping() -> JSONResponse:
    """Report the result of the engine's ``is_sleeping()`` query."""
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    return JSONResponse({"sleeping": await engine.is_sleeping()})
# =============================================================================
# Debug Endpoints
# =============================================================================
@app.get("/bridge/debug")
async def bridge_debug() -> JSONResponse:
    """
    Dump engine type, vLLM version, patch status and related environment
    variables. Works without an engine; model details are added when they
    can be read.
    """
    env = os.environ
    debug_info = {
        "engine_type": type(engine).__name__ if engine else None,
        "vllm_version": VLLM_VERSION,
        "patches_applied": PATCHES_APPLIED,
        "shared_weights_env": env.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"),
        "num_inference_nodes": env.get("NUM_INFERENCE_NODES", "unset"),
        "logdir": env.get("LOGDIR", "unset"),
    }
    if engine is None:
        return JSONResponse(debug_info)

    try:
        debug_info["model_config"] = {
            "model": str(engine.model_config.model),
            "dtype": str(engine.model_config.dtype),
        }
    except Exception:
        # Best effort: the model section is simply left out when unreadable.
        pass
    return JSONResponse(debug_info)
@app.get("/bridge/list_endpoints")
async def list_endpoints() -> JSONResponse:
    """Enumerate every registered route that exposes a path and a set of methods."""
    endpoints = [
        {"path": route.path, "methods": list(route.methods)}
        for route in app.routes
        if hasattr(route, "path") and hasattr(route, "methods")
    ]
    return JSONResponse({"endpoints": endpoints})
# =============================================================================
# LoRA Endpoints
# =============================================================================
@app.get("/lora/status")
async def lora_status() -> LoraStatusResponse:
    """
    Report the active LoRA adapter and the adapters found under LOGDIR.

    An adapter is any direct sub-directory of LOGDIR (default ".") that
    contains an ``adapter_config.json``. The list is sorted by name so the
    answer does not depend on directory iteration order.
    """
    log_dir = os.environ.get("LOGDIR", ".")
    available: List[str] = []
    # isdir() rather than exists(): os.listdir would raise if LOGDIR pointed
    # at a regular file.
    if os.path.isdir(log_dir):
        available = sorted(
            item
            for item in os.listdir(log_dir)
            if os.path.isdir(os.path.join(log_dir, item))
            and os.path.exists(os.path.join(log_dir, item, "adapter_config.json"))
        )

    return LoraStatusResponse(
        lora_available=LORA_AVAILABLE,
        active_adapter_path=bridge_state.active_lora_path,
        active_adapter_name=bridge_state.active_lora_name,
        # The id is only meaningful while an adapter path is active.
        active_adapter_id=(
            bridge_state.active_lora_id if bridge_state.active_lora_path else None
        ),
        load_count=bridge_state.lora_load_count,
        available_adapters=available,
    )
@app.post("/lora/load")
async def lora_load(request: LoraLoadRequest) -> JSONResponse:
    """
    Make a LoRA adapter the active one for subsequent /generate calls.

    Args:
        request: Adapter path (must exist) and optional adapter name; when no
            name is given, "adapter_<load count>" is used.

    Returns:
        JSON with status, adapter path/name/id, the new load count and the
        fields read from ``adapter_config.json`` (empty when unreadable).

    Raises:
        HTTPException: 501 when LoRA support could not be imported from
            vLLM, 404 when the adapter path does not exist.
    """
    if not LORA_AVAILABLE:
        # Without LoRARequest the adapter would never be applied to a
        # generation, so fail loudly instead of reporting success.
        raise HTTPException(
            status_code=501,
            detail="LoRA support is not available in this vLLM installation",
        )
    if not os.path.exists(request.adapter_path):
        raise HTTPException(
            status_code=404, detail=f"Adapter not found: {request.adapter_path}"
        )

    # Read adapter config to validate and log details
    adapter_config_path = os.path.join(request.adapter_path, "adapter_config.json")
    adapter_info = {}
    if os.path.exists(adapter_config_path):
        try:
            with open(adapter_config_path, "r", encoding="utf-8") as f:
                adapter_config = json.load(f)
            adapter_info = {
                "r": adapter_config.get("r"),
                "lora_alpha": adapter_config.get("lora_alpha"),
                "target_modules": adapter_config.get("target_modules"),
                "base_model": adapter_config.get("base_model_name_or_path"),
            }
            logger.info(f"LoRA adapter config: {adapter_info}")
        except Exception as e:
            # The config is informational only; an unreadable file must not
            # block the load.
            logger.warning(f"Could not read adapter_config.json: {e}")
    else:
        logger.warning(f"No adapter_config.json found at {adapter_config_path}")

    with bridge_state.lock:
        previous_loads = bridge_state.lora_load_count
        adapter_name = request.adapter_name or f"adapter_{previous_loads}"
        adapter_id = previous_loads + 1  # vLLM needs unique int ID
        load_count = previous_loads + 1

        bridge_state.active_lora_path = request.adapter_path
        bridge_state.active_lora_name = adapter_name
        bridge_state.active_lora_id = adapter_id
        bridge_state.lora_load_count = load_count

    # Log and answer from the values captured under the lock, so a concurrent
    # load/unload cannot change what this call reports.
    logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={adapter_id})")
    return JSONResponse(
        {
            "status": "ok",
            "adapter_path": request.adapter_path,
            "adapter_name": adapter_name,
            "adapter_id": adapter_id,
            "load_count": load_count,
            "adapter_config": adapter_info,
        }
    )
@app.post("/lora/unload")
async def lora_unload() -> JSONResponse:
    """Clear the active LoRA adapter and report which one was active before."""
    with bridge_state.lock:
        prev_path, prev_name = (
            bridge_state.active_lora_path,
            bridge_state.active_lora_name,
        )
        bridge_state.active_lora_path = None
        bridge_state.active_lora_name = None
        bridge_state.active_lora_id = 0

    logger.info(f"LoRA adapter unloaded: {prev_path} ({prev_name})")
    summary = {
        "status": "ok",
        "previous_adapter": prev_path,
        "previous_name": prev_name,
    }
    return JSONResponse(summary)
# =============================================================================
# NCCL Weight Receiver (for lora_nccl mode)
# =============================================================================
# Module-level NCCL receiver state, shared between the /nccl/* handlers and
# the background receiver thread.
# Bridge object created by the receiver thread; None when no receiver exists.
nccl_bridge: Optional[Any] = None  # Will hold NCCLWeightBridge instance
# Thread started by /nccl/start_receiver; reset to None by /nccl/stop_receiver.
nccl_setup_thread: Optional[threading.Thread] = None
# Last error message recorded by the receiver thread; None when there is none.
nccl_setup_error: Optional[str] = None
def _import_nccl_bridge():
    """
    Import NCCLBridgeConfig and NCCLWeightBridge from nccl_weight_bridge.

    Tries a relative import, then a plain import, then a plain import after
    adding this file's directory to sys.path.

    Returns:
        The ``(NCCLBridgeConfig, NCCLWeightBridge)`` pair.

    Raises:
        ImportError: when all three attempts fail.
    """
    try:
        from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge

        logger.info("[NCCL] Imported via relative import")
        return NCCLBridgeConfig, NCCLWeightBridge
    except ImportError:
        pass

    try:
        from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge

        logger.info("[NCCL] Imported via direct import")
        return NCCLBridgeConfig, NCCLWeightBridge
    except ImportError:
        pass

    import sys
    from pathlib import Path

    script_dir = Path(__file__).parent
    if str(script_dir) not in sys.path:
        sys.path.insert(0, str(script_dir))
    from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge

    logger.info("[NCCL] Imported via sys.path manipulation")
    return NCCLBridgeConfig, NCCLWeightBridge


def _setup_nccl_receiver_thread(
    init_method: str,
    world_size: int,
    param_metadata: dict,
    param_mappings: dict,
):
    """
    Thread target: set up the NCCL bridge as rank 1 and receive weight updates.

    Args:
        init_method: Passed to NCCLBridgeConfig as ``init_method``.
        world_size: Passed to NCCLBridgeConfig as ``world_size``.
        param_metadata: Dict with optional "param_names", "param_shapes" and
            "param_dtypes" entries that are copied onto the bridge.
        param_mappings: Accepted for interface compatibility; currently unused.

    Progress is published through the module globals ``nccl_bridge`` and
    ``nccl_setup_error``. The loop works on its own local bridge reference
    and stops as soon as the global no longer points at it, so a thread that
    was superseded by a later start/stop never touches the newer bridge.
    """
    global nccl_bridge, nccl_setup_error
    import traceback

    logger.info("[NCCL] Receiver thread started, attempting to import nccl_weight_bridge...")

    try:
        NCCLBridgeConfig, NCCLWeightBridge = _import_nccl_bridge()
    except ImportError as e:
        nccl_setup_error = f"NCCL weight bridge module not available: {e}"
        logger.error(nccl_setup_error)
        return

    try:
        config = NCCLBridgeConfig(
            rank=1,  # vLLM is rank 1 (trainer is rank 0)
            world_size=world_size,
            init_method=init_method,
        )
        logger.info(f"[NCCL] Starting setup as rank 1, world_size={world_size}")

        bridge = NCCLWeightBridge(config)
        # Publish immediately so the status/stop endpoints can see the bridge
        # while setup() is still waiting for the trainer.
        nccl_bridge = bridge
        if not bridge.setup():
            nccl_setup_error = "Failed to setup NCCL bridge"
            logger.error(nccl_setup_error)
            return

        # Set param metadata from trainer
        bridge.param_names = param_metadata.get("param_names", [])
        bridge.param_shapes = {
            k: tuple(v) for k, v in param_metadata.get("param_shapes", {}).items()
        }
        bridge.param_dtypes = param_metadata.get("param_dtypes", {})
        logger.info(
            f"[NCCL] ✓ Bridge setup complete, {len(bridge.param_names)} params registered"
        )
        logger.info("[NCCL] Starting receiver loop to wait for weight updates...")

        # Receiver loop - wait for weight updates from the trainer for as
        # long as this thread's bridge is still the published one.
        while nccl_bridge is bridge:
            try:
                step, weights = bridge.receive_lora_weights()
                if step < 0:
                    logger.info("[NCCL] Received shutdown signal, exiting receiver loop")
                    break
                logger.info(
                    f"[NCCL] ✓ Received weights for step {step} ({len(weights)} params)"
                )
                # TODO: Apply weights to vLLM's LoRA adapter
                # For now, we just log receipt - actual weight application
                # would require access to vLLM's internal model state
            except Exception as e:
                if "shutdown" in str(e).lower() or nccl_bridge is not bridge:
                    # Expected when the receiver was stopped or replaced.
                    logger.info("[NCCL] Receiver loop terminated")
                    break
                # Record the failure so /nccl/status reports it instead of
                # still presenting the dead receiver as healthy.
                nccl_setup_error = f"NCCL receive error: {e}"
                logger.error(f"[NCCL] Error receiving weights: {e}")
                traceback.print_exc()
                break
    except Exception as e:
        nccl_setup_error = f"NCCL setup error: {e}"
        logger.error(nccl_setup_error)
        traceback.print_exc()
@app.post("/nccl/start_receiver")
async def nccl_start_receiver(request: Request) -> JSONResponse:
    """
    Start NCCL weight receiver (for lora_nccl training mode).

    NCCL setup runs in a background thread so that this request can return
    while vLLM (rank 1) waits for the trainer (rank 0) to join the group.
    Any bridge left over from an earlier call is cleaned up first.

    Request JSON:
    {
        "init_method": "tcp://localhost:29500",
        "world_size": 2,
        "param_metadata": {...},
        "param_mappings": {...}
    }

    Returns:
        {"status": "ok", ...} once the thread is running, or a 500 response
        with {"status": "error", ...} when the thread reported an error
        within the first half second.

    Raises:
        HTTPException: 503 when the engine is not initialized.
    """
    global nccl_bridge, nccl_setup_thread, nccl_setup_error

    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    # Stop any existing bridge
    if nccl_bridge is not None:
        try:
            nccl_bridge.cleanup()
        except Exception:
            # Best effort: a broken old bridge must not prevent a restart.
            pass
        nccl_bridge = None
    nccl_setup_error = None

    request_dict = await request.json()
    init_method = request_dict.get("init_method", "tcp://localhost:29500")
    world_size = request_dict.get("world_size", 2)
    param_metadata = request_dict.get("param_metadata", {})
    param_mappings = request_dict.get("param_mappings", {})

    logger.info(
        f"[NCCL] Received start_receiver request: init_method={init_method}, world_size={world_size}"
    )
    logger.info(
        f"[NCCL] Param metadata: {len(param_metadata.get('param_names', []))} params"
    )

    # Start NCCL setup in background thread
    # This allows the HTTP response to return immediately while NCCL joins
    nccl_setup_thread = threading.Thread(
        target=_setup_nccl_receiver_thread,
        args=(init_method, world_size, param_metadata, param_mappings),
        daemon=True,
        name="nccl_receiver_thread",
    )
    nccl_setup_thread.start()

    # Give the thread a moment to surface immediate errors. This must be an
    # awaited sleep: a blocking time.sleep() would freeze the event loop and
    # stall every other request for the duration.
    await asyncio.sleep(0.5)

    if nccl_setup_error:
        return JSONResponse(
            {
                "status": "error",
                "message": nccl_setup_error,
            },
            status_code=500,
        )

    logger.info("[NCCL] Receiver thread started, waiting for trainer to connect...")
    return JSONResponse(
        {
            "status": "ok",
            "message": "NCCL receiver setup started - waiting for trainer to connect",
            "rank": 1,
            "world_size": world_size,
        }
    )
@app.post("/nccl/stop_receiver")
async def nccl_stop_receiver() -> JSONResponse:
    """
    Clean up the NCCL bridge, if one exists, and forget the receiver thread.

    Cleanup errors are logged as warnings; the answer is always status "ok".
    """
    global nccl_bridge, nccl_setup_thread

    if nccl_bridge is not None:
        try:
            nccl_bridge.cleanup()
        except Exception as e:
            logger.warning(f"Error during NCCL cleanup: {e}")
        nccl_bridge = None
        nccl_setup_thread = None
        message = "NCCL receiver stopped"
    else:
        message = "No receiver running"

    return JSONResponse({"status": "ok", "message": message})
@app.get("/nccl/status")
async def nccl_status() -> JSONResponse:
    """
    Summarize the NCCL receiver state.

    Checked in this order: thread still running ("connecting"), recorded
    error ("error"), no bridge ("not_started"), otherwise the bridge's own
    initialization flag and counters.
    """
    thread = nccl_setup_thread
    if thread is not None and thread.is_alive():
        payload = {
            "active": False,
            "status": "connecting",
            "message": "NCCL setup in progress...",
        }
    elif nccl_setup_error is not None:
        payload = {
            "active": False,
            "status": "error",
            "error": nccl_setup_error,
        }
    elif nccl_bridge is None:
        payload = {
            "active": False,
            "status": "not_started",
            "update_count": 0,
        }
    else:
        initialized = nccl_bridge.is_initialized
        payload = {
            "active": initialized,
            "status": "connected" if initialized else "disconnected",
            "update_count": nccl_bridge.update_count,
            "last_update_time": nccl_bridge.last_update_time,
            "num_params": len(nccl_bridge.param_names),
        }
    return JSONResponse(payload)
# =============================================================================
# Server Setup
# =============================================================================
def build_app(args: Namespace) -> FastAPI:
    """Set the module-level FastAPI app's root path from ``args`` and return the app."""
    root_path = args.root_path
    app.root_path = root_path
    return app
async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
    """
    Prepare the FastAPI app and bind the module-level engine.

    Uses ``llm_engine`` when one is supplied; otherwise an AsyncLLM is built
    from the CLI arguments. Also writes the basic bridge config file.
    """
    global engine

    application = build_app(args)
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        engine = AsyncLLM.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    else:
        engine = llm_engine
    application.state.engine_client = engine

    # Export basic state dict info for trainers (the patched runner exports detailed info)
    _export_state_dict_info(args)
    return application
def _export_state_dict_info(args: Namespace) -> None:
    """
    Write a minimal bridge config JSON for the trainer (best effort).

    The target is VLLM_BRIDGE_CONFIG_PATH when set, otherwise
    ``$LOGDIR/vllm_bridge_config.json`` (LOGDIR defaults to "."). An existing
    file that already holds a non-empty "param_mappings" entry is left
    untouched, since the patched runner writes the complete version. A
    missing, empty or unparsable file is (re)written with basic info.

    Args:
        args: Parsed CLI arguments; "model", "dtype" and
            "tensor_parallel_size" are read when present.

    Failures are logged as warnings and never raised.
    """
    # Allow explicit config path via env var, otherwise use LOGDIR
    config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH")
    if config_path:
        json_path = Path(config_path)
    else:
        json_path = Path(os.environ.get("LOGDIR", ".")) / "vllm_bridge_config.json"

    try:
        # Inside the try block: an unwritable directory must only produce a
        # warning, not abort server startup.
        json_path.parent.mkdir(parents=True, exist_ok=True)

        if json_path.exists():
            try:
                with open(json_path, "r", encoding="utf-8") as f:
                    existing = json.load(f)
            except ValueError:
                # Empty or corrupt file (JSONDecodeError is a ValueError):
                # treat as absent so it gets overwritten below.
                existing = None
            if isinstance(existing, dict) and existing.get("param_mappings"):
                logger.info("Config already has param_mappings, not overwriting")
                return

        info = {
            "model": getattr(args, "model", "unknown"),
            "dtype": getattr(args, "dtype", "auto"),
            "tp_degree": getattr(args, "tensor_parallel_size", 1),
            "dp_shard_degree": 1,
            "param_mappings": {},
            "shared_weights_enabled": PATCHES_APPLIED,
        }
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)
        logger.info(f"Exported basic state dict info to {json_path}")
    except Exception as e:
        logger.warning(f"Failed to export state dict info: {e}")
async def run_server(
    args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any
) -> None:
    """
    Initialize the app and engine, log the endpoint overview and serve HTTP
    until the server's shutdown task completes.

    Raises:
        RuntimeError: when no engine is available after initialization.
    """
    banner = "=" * 60

    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)
    if PATCHES_APPLIED:
        for line in (
            banner,
            "SHARED MEMORY MODE ENABLED",
            "Weight updates from trainer will be reflected immediately!",
            banner,
        ):
            logger.info(line)

    set_ulimit()
    app = await init_app(args, llm_engine)
    if engine is None:
        raise RuntimeError("No engine initialized")

    # Log available endpoints
    for line in (
        banner,
        "Streamlined vLLM Server - Training-Focused API",
        "Available endpoints:",
        " POST /generate - Generate with logprobs (primary endpoint)",
        " GET /health - Health check",
        " GET /bridge/info - Bridge status",
        " POST /bridge/pause - Pause generation",
        " POST /bridge/resume - Resume generation",
        " GET /lora/status - LoRA adapter status",
        " POST /lora/load - Load LoRA adapter",
        " POST /lora/unload - Unload LoRA adapter",
        " POST /nccl/start_receiver - Start NCCL weight receiver (lora_nccl mode)",
        " POST /nccl/stop_receiver - Stop NCCL weight receiver",
        " GET /nccl/status - NCCL receiver status",
        banner,
    ):
        logger.info(line)

    # Optional CLI attributes fall back to defaults when args lacks them.
    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False),
        host=args.host,
        port=args.port,
        log_level=getattr(args, "log_level", "info"),
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=getattr(args, "ssl_keyfile", None),
        ssl_certfile=getattr(args, "ssl_certfile", None),
        ssl_ca_certs=getattr(args, "ssl_ca_certs", None),
        ssl_cert_reqs=getattr(args, "ssl_cert_reqs", ssl.CERT_NONE),
        **uvicorn_kwargs,
    )
    await shutdown_task
if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    # Server-level options, declared as (flag, add_argument keyword arguments).
    for flag, options in (
        ("--host", {"type": str, "default": None}),
        ("--port", {"type": int, "default": 9001}),
        ("--ssl-keyfile", {"type": str, "default": None}),
        ("--ssl-certfile", {"type": str, "default": None}),
        ("--ssl-ca-certs", {"type": str, "default": None}),
        ("--enable-ssl-refresh", {"action": "store_true", "default": False}),
        ("--ssl-cert-reqs", {"type": int, "default": int(ssl.CERT_NONE)}),
        ("--root-path", {"type": str, "default": None}),
        ("--log-level", {"type": str, "default": "info"}),
    ):
        parser.add_argument(flag, **options)

    # Add vLLM engine args
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))