Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-25 17:10:42 +00:00

Commit 3ed23058c3 (parent 407a22ba12): initial commit

5 changed files with 2452 additions and 399 deletions
@@ -2,23 +2,49 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-NOTE: This API server is used only for demonstrating usage of AsyncEngine
-and simple performance benchmarks. It is not intended for production use.
-For production use, we recommend using our OpenAI compatible server.
-We are also not going to accept PRs modifying this file, please
-change `vllm/entrypoints/openai/api_server.py` instead.
+Custom vLLM API server with weight bridge hooks for shared-memory training.
+
+This server extends the standard vLLM API with endpoints for:
+- Shared-weight training: trainers can attach to model weights via NCCL
+- LoRA hot-swap: load new adapters without server restart
+- Weight synchronization: coordinate updates between trainer and inference
+
+Architecture:
+    ┌─────────────────────────────────────────────────────────┐
+    │                   vllm_api_server.py                    │
+    │   ┌────────────────────────────────────────────────┐    │
+    │   │              FastAPI Application               │    │
+    │   │  ┌─────────┐  ┌──────────┐  ┌───────────────┐  │    │
+    │   │  │/generate│  │/bridge/* │  │    /lora/*    │  │    │
+    │   │  │ (infer) │  │  (sync)  │  │  (adapters)   │  │    │
+    │   │  └────┬────┘  └────┬─────┘  └───────┬───────┘  │    │
+    │   └───────┼────────────┼────────────────┼──────────┘    │
+    │           │            │                │               │
+    │   ┌───────▼────────────▼────────────────▼──────────┐    │
+    │   │                 AsyncLLMEngine                  │    │
+    │   │  - Model weights (shared via NCCL)              │    │
+    │   │  - LoRA adapters (hot-swappable)                │    │
+    │   └────────────────────────────────────────────────┘    │
+    └─────────────────────────────────────────────────────────┘
 """

 import asyncio
 import json
+import os
 import ssl
+import threading
+import time
 from argparse import Namespace
 from collections.abc import AsyncGenerator
-from typing import Any
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional

+import torch
 import vllm.envs as envs
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from pydantic import BaseModel
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.utils import with_cancellation
@@ -38,20 +64,99 @@ from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger("vllm.entrypoints.api_server")


+# =============================================================================
+# Global State
+# =============================================================================
+
 app = FastAPI()
-engine = None
+engine: Optional[AsyncLLMEngine] = None


+@dataclass
+class BridgeState:
+    """State for weight bridge synchronization."""
+
+    enabled: bool = False
+    update_count: int = 0
+    last_update_time: float = 0.0
+    rendezvous_info: Dict[str, Any] = field(default_factory=dict)
+    lock: threading.Lock = field(default_factory=threading.Lock)
+
+    # LoRA state
+    active_lora_path: Optional[str] = None
+    lora_load_count: int = 0
+
+
+bridge_state = BridgeState()
+
+
+# =============================================================================
+# Pydantic Models for API
+# =============================================================================
+
+
+class BridgeInfoResponse(BaseModel):
+    """Response model for the bridge info endpoint."""
+
+    enabled: bool
+    update_count: int
+    last_update_time: float
+    rendezvous_info: Dict[str, Any]
+    model_name: str
+    device: str
+
+
+class BridgeInitRequest(BaseModel):
+    """Request model for initializing the bridge."""
+
+    master_addr: str
+    master_port: int
+    world_size: int
+    trainer_ranks: List[int]
+
+
+class WeightUpdateNotification(BaseModel):
+    """Notification that weights have been updated."""
+
+    update_count: int
+    trainer_rank: int
+    timestamp: float
+
+
+class LoraLoadRequest(BaseModel):
+    """Request to load a LoRA adapter."""
+
+    adapter_path: str
+    adapter_name: Optional[str] = None
+
+
+class LoraStatusResponse(BaseModel):
+    """Response model for LoRA status."""
+
+    active_adapter: Optional[str]
+    load_count: int
+    available_adapters: List[str]
+
+
+# =============================================================================
+# Health Endpoints
+# =============================================================================
+
+
 @app.get("/health")
 async def health() -> Response:
-    """Health check."""
+    """Basic health check - is the server alive?"""
     return Response(status_code=200)


 @app.get("/health_generate")
 async def health_generate() -> Response:
     """
-    Check the health of the inference server by sending a special request to generate one token.
+    Deep health check - can we actually generate tokens?
+
+    This sends a minimal request through the full inference pipeline
+    to verify the model is loaded and functioning.
     """
     assert engine is not None
     sampling_params = SamplingParams()
@@ -67,14 +172,30 @@ async def health_generate() -> Response:
     return Response(status_code=200)


+# =============================================================================
+# Generation Endpoint
+# =============================================================================
+
+
 @app.post("/generate")
 async def generate(request: Request) -> Response:
-    """Generate completion for the request.
+    """
+    Generate text completion for a prompt.

-    The request should be a JSON object with the following fields:
-    - prompt: the prompt to use for the generation.
-    - stream: whether to stream the results or not.
-    - other fields: the sampling parameters (See `SamplingParams` for details).
+    Request JSON fields:
+    - prompt: str - The input text to complete
+    - stream: bool - Whether to stream results (default: False)
+    - max_tokens: int - Maximum tokens to generate
+    - temperature: float - Sampling temperature
+    - top_p: float - Nucleus sampling threshold
+    - logprobs: int - Number of logprobs to return per token
+
+    Returns:
+    - text: List[str] - Generated completions
+    - prompt: str - Echo of input prompt
+    - finish_reasons: List[str] - Why generation stopped
+    - logprobs: List (optional) - Token log probabilities
+    - token_ids: List (optional) - Generated token IDs
     """
     request_dict = await request.json()
     return await _generate(request_dict, raw_request=request)
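# Sketch of a /generate call exercising the fields documented above; the
# prompt and sampling values are illustrative, and a local server plus the
# `requests` package are assumed.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "The capital of France is",
        "stream": False,
        "max_tokens": 8,
        "temperature": 0.0,
        "logprobs": 1,  # request per-token logprobs (useful for RL training)
    },
)
out = resp.json()
print(out["text"], out["finish_reasons"])
# With logprobs set, the response also carries "logprobs", "token_ids",
# and "prompt_token_ids" (see the non-streaming path below).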
@@ -91,7 +212,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)

-    # Streaming case
+    # Streaming: yield results as they're generated
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
             prompt = request_output.prompt
@@ -103,7 +224,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     if stream:
         return StreamingResponse(stream_results())

-    # Non-streaming case
+    # Non-streaming: wait for full completion
     final_output = None
     try:
         async for request_output in results_generator:
@@ -119,6 +240,8 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     text_outputs = [output.text for output in final_output.outputs]
     finish_reasons = [output.finish_reason for output in final_output.outputs]
     ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
+
+    # Include logprobs if requested (useful for RL training)
     if sampling_params.logprobs is not None:
         output_logprobs = [
             [
@@ -132,12 +255,239 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
         ret["logprobs"] = output_logprobs
         ret["prompt_token_ids"] = prompt_token_ids
         ret["token_ids"] = output_token_ids

     return JSONResponse(ret)


-def build_app(args: Namespace) -> FastAPI:
-    global app  # noqa: F824
+# =============================================================================
+# Bridge Endpoints (for shared-weight training)
+# =============================================================================
+
+
+@app.get("/bridge/info", response_model=BridgeInfoResponse)
+async def bridge_info() -> BridgeInfoResponse:
+    """
+    Get bridge status and rendezvous information.
+
+    Trainers call this to discover how to connect to the weight-sharing
+    process group. Returns connection details and current sync state.
+    """
+    assert engine is not None
+
+    return BridgeInfoResponse(
+        enabled=bridge_state.enabled,
+        update_count=bridge_state.update_count,
+        last_update_time=bridge_state.last_update_time,
+        rendezvous_info=bridge_state.rendezvous_info,
+        model_name=str(engine.engine.model_config.model),
+        device=str(next(iter(engine.engine.model_executor.driver_worker.model_runner.model.parameters())).device)
+        if hasattr(engine, "engine") else "unknown",
+    )
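# Trainer-side discovery sketch: poll /bridge/info until the bridge has been
# initialized, then read the rendezvous details. The base URL, timeout, and
# helper name are assumptions.
import time

import requests


def wait_for_bridge(base_url: str = "http://localhost:8000",
                    timeout: float = 60.0) -> dict:
    """Block until /bridge/info reports enabled, then return rendezvous info."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        info = requests.get(f"{base_url}/bridge/info").json()
        if info["enabled"]:
            return info["rendezvous_info"]
        time.sleep(1.0)
    raise TimeoutError("weight bridge was never enabled")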
+
+
+@app.post("/bridge/init")
+async def bridge_init(request: BridgeInitRequest) -> JSONResponse:
+    """
+    Initialize the weight bridge for shared-memory training.
+
+    This sets up the rendezvous information that trainers need to join
+    the same NCCL process group as this inference server.
+
+    Called once when setting up a training run.
+    """
+    with bridge_state.lock:
+        bridge_state.enabled = True
+        bridge_state.rendezvous_info = {
+            "master_addr": request.master_addr,
+            "master_port": request.master_port,
+            "world_size": request.world_size,
+            "trainer_ranks": request.trainer_ranks,
+            "initialized_at": time.time(),
+        }
+
+    logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}")
+    return JSONResponse({"status": "ok", "rendezvous_info": bridge_state.rendezvous_info})
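# Sketch of how a trainer might call /bridge/init and then join the NCCL
# process group. The address, world size, and rank layout are illustrative
# assumptions; the server's own side of the group setup is not shown in this
# diff.
import requests
import torch.distributed as dist

rv = requests.post(
    "http://localhost:8000/bridge/init",
    json={
        "master_addr": "10.0.0.1",  # illustrative trainer-reachable address
        "master_port": 29500,
        "world_size": 2,            # e.g. one inference rank + one trainer
        "trainer_ranks": [1],
    },
).json()["rendezvous_info"]

dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{rv['master_addr']}:{rv['master_port']}",
    world_size=rv["world_size"],
    rank=rv["trainer_ranks"][0],
)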
+
+
+@app.post("/bridge/notify_update")
+async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse:
+    """
+    Receive notification that the trainer has updated weights.
+
+    After optimizer.step(), the trainer calls this to signal that the
+    shared weights have been modified. The server can use this to:
+    - Log the update for debugging
+    - Invalidate any cached KV states if needed
+    - Track synchronization for metrics
+
+    In shared-memory mode, the weights are already updated in place,
+    so no data transfer happens here - this is just coordination.
+    """
+    with bridge_state.lock:
+        bridge_state.update_count = notification.update_count
+        bridge_state.last_update_time = notification.timestamp
+
+    logger.info(
+        f"Weight update #{notification.update_count} from trainer {notification.trainer_rank}"
+    )
+
+    return JSONResponse({
+        "status": "ok",
+        "update_count": bridge_state.update_count,
+        "server_time": time.time(),
+    })
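# Coordination sketch from the trainer's side: after each in-place update of
# the shared weights, post a notification. Endpoint and field names match
# WeightUpdateNotification above; the helper itself is hypothetical.
import time

import requests


def notify_weight_update(step: int, trainer_rank: int = 1,
                         base_url: str = "http://localhost:8000") -> int:
    """Tell the server the weights changed; returns the server's update count."""
    resp = requests.post(
        f"{base_url}/bridge/notify_update",
        json={"update_count": step,
              "trainer_rank": trainer_rank,
              "timestamp": time.time()},
    )
    return resp.json()["update_count"]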
+
+
+@app.get("/bridge/state_dict_info")
+async def bridge_state_dict_info() -> JSONResponse:
+    """
+    Get information about the model's state dict for weight attachment.
+
+    Returns parameter names, shapes, and dtypes so trainers can properly
+    map their tensors to the inference model's parameters.
+    """
+    assert engine is not None
+
+    try:
+        # Access the underlying model
+        model = engine.engine.model_executor.driver_worker.model_runner.model
+        state_dict_info = {}
+
+        for name, param in model.named_parameters():
+            state_dict_info[name] = {
+                "shape": list(param.shape),
+                "dtype": str(param.dtype),
+                "device": str(param.device),
+                "requires_grad": param.requires_grad,
+            }
+
+        return JSONResponse({
+            "status": "ok",
+            "num_parameters": len(state_dict_info),
+            "total_params": sum(p.numel() for p in model.parameters()),
+            "parameters": state_dict_info,
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to get state dict info: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
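# Sketch of using /bridge/state_dict_info to sanity-check that a trainer's
# parameters line up with the served model before attaching weights. The
# helper and its arguments are hypothetical.
import requests
import torch


def check_param_alignment(trainer_model: torch.nn.Module,
                          base_url: str = "http://localhost:8000") -> list:
    """Return parameter names that are missing or shape-mismatched on the server."""
    server_params = requests.get(
        f"{base_url}/bridge/state_dict_info"
    ).json()["parameters"]
    problems = []
    for name, param in trainer_model.named_parameters():
        remote = server_params.get(name)
        if remote is None:
            problems.append(f"missing on server: {name}")
        elif list(param.shape) != remote["shape"]:
            problems.append(f"shape mismatch: {name}")
    return problems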
+
+
+@app.post("/bridge/disable")
+async def bridge_disable() -> JSONResponse:
+    """
+    Disable the weight bridge.
+
+    Called when training ends or if the trainer disconnects.
+    """
+    with bridge_state.lock:
+        bridge_state.enabled = False
+        bridge_state.rendezvous_info = {}
+
+    logger.info("Bridge disabled")
+    return JSONResponse({"status": "ok"})
+
+
+# =============================================================================
+# LoRA Endpoints (for adapter hot-swapping)
+# =============================================================================
+
+
+@app.get("/lora/status", response_model=LoraStatusResponse)
+async def lora_status() -> LoraStatusResponse:
+    """
+    Get current LoRA adapter status.
+
+    Returns which adapter is active (if any) and lists available adapters
+    in the configured adapter directory.
+    """
+    # List available adapters from the save path
+    adapter_dir = os.environ.get("LORA_ADAPTER_DIR", "./adapters")
+    available = []
+    if os.path.isdir(adapter_dir):
+        for item in os.listdir(adapter_dir):
+            item_path = os.path.join(adapter_dir, item)
+            # Check if it looks like a PEFT adapter
+            if os.path.isdir(item_path) and os.path.exists(
+                os.path.join(item_path, "adapter_config.json")
+            ):
+                available.append(item)
+
+    return LoraStatusResponse(
+        active_adapter=bridge_state.active_lora_path,
+        load_count=bridge_state.lora_load_count,
+        available_adapters=available,
+    )
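# The directory layout /lora/status scans for: each subdirectory of
# LORA_ADAPTER_DIR that contains an adapter_config.json counts as one
# available adapter. Directory names below are illustrative.
#
#   adapters/
#   ├── run-step-100/
#   │   ├── adapter_config.json
#   │   └── adapter_model.safetensors
#   └── run-step-200/
#       ├── adapter_config.json
#       └── adapter_model.safetensors
import requests

print(requests.get("http://localhost:8000/lora/status").json())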
+
+
+@app.post("/lora/load")
+async def lora_load(request: LoraLoadRequest) -> JSONResponse:
+    """
+    Hot-swap a LoRA adapter without restarting the server.
+
+    The adapter is loaded from disk and merged with the base model weights.
+    This is much faster than restarting vLLM with a new checkpoint.
+
+    Note: This requires the PEFT library and a compatible vLLM version.
+    """
+    adapter_path = request.adapter_path
+
+    if not os.path.exists(adapter_path):
+        raise HTTPException(status_code=404, detail=f"Adapter not found: {adapter_path}")
+
+    if not os.path.exists(os.path.join(adapter_path, "adapter_config.json")):
+        raise HTTPException(
+            status_code=400, detail=f"Invalid adapter (missing adapter_config.json): {adapter_path}"
+        )
+
+    try:
+        # TODO: Implement actual LoRA loading for vLLM
+        # This depends on vLLM's LoRA support which varies by version
+        # For now, we track the state and log the request
+
+        with bridge_state.lock:
+            bridge_state.active_lora_path = adapter_path
+            bridge_state.lora_load_count += 1
+
+        logger.info(f"LoRA adapter loaded: {adapter_path}")
+
+        return JSONResponse({
+            "status": "ok",
+            "adapter_path": adapter_path,
+            "load_count": bridge_state.lora_load_count,
+            "message": "Adapter registered (actual loading depends on vLLM version)",
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to load LoRA adapter: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/lora/unload")
+async def lora_unload() -> JSONResponse:
+    """
+    Unload the current LoRA adapter, reverting to base model weights.
+    """
+    with bridge_state.lock:
+        prev_adapter = bridge_state.active_lora_path
+        bridge_state.active_lora_path = None
+
+    logger.info(f"LoRA adapter unloaded: {prev_adapter}")
+
+    return JSONResponse({
+        "status": "ok",
+        "previous_adapter": prev_adapter,
+    })
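# Sketch of a full hot-swap cycle against the endpoints above. Note that
# /lora/load currently only registers the adapter (see the TODO), so this
# exercises the bookkeeping rather than an actual weight merge.
import requests

BASE = "http://localhost:8000"

requests.post(f"{BASE}/lora/load",
              json={"adapter_path": "adapters/run-step-100"})  # illustrative path
print(requests.get(f"{BASE}/lora/status").json())  # active_adapter is set
requests.post(f"{BASE}/lora/unload")
print(requests.get(f"{BASE}/lora/status").json())  # active_adapter is None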
+
+
+# =============================================================================
+# Server Setup
+# =============================================================================
+
+
+def build_app(args: Namespace) -> FastAPI:
+    """Build the FastAPI application with configured root path."""
+    global app  # noqa: F824
     app.root_path = args.root_path
     return app


@@ -146,10 +496,19 @@ async def init_app(
     args: Namespace,
     llm_engine: AsyncLLMEngine | None = None,
 ) -> FastAPI:
+    """
+    Initialize the application and vLLM engine.
+
+    Args:
+        args: Parsed command-line arguments
+        llm_engine: Optional pre-created engine (for testing)
+
+    Returns:
+        Configured FastAPI application
+    """
     app = build_app(args)

     global engine

     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = (
         llm_engine
@@ -159,12 +518,51 @@
         )
     )
     app.state.engine_client = engine

+    # Export state dict info for trainers
+    _export_state_dict_info(args)
+
     return app


+def _export_state_dict_info(args: Namespace) -> None:
+    """
+    Export model parameter mapping to JSON for trainer attachment.
+
+    This writes a file that trainers can read to understand how to
+    map their parameters to the inference model's parameters.
+    """
+    log_dir = os.environ.get("LOGDIR", ".")
+    json_path = Path(log_dir) / "vllm_bridge_config.json"
+
+    try:
+        # Basic info - actual param mappings added when bridge is initialized
+        info = {
+            "model": getattr(args, "model", "unknown"),
+            "dtype": getattr(args, "dtype", "auto"),
+            "tp_degree": getattr(args, "tensor_parallel_size", 1),
+            "dp_shard_degree": 1,  # Data parallel sharding
+            "param_mappings": {},
+        }
+
+        with open(json_path, "w") as f:
+            json.dump(info, f, indent=2)
+
+        logger.info(f"Exported state dict info to {json_path}")
+
+    except Exception as e:
+        logger.warning(f"Failed to export state dict info: {e}")
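# Trainer-side sketch for consuming the exported file; the file name and
# LOGDIR convention come from _export_state_dict_info above.
import json
import os
from pathlib import Path

cfg_path = Path(os.environ.get("LOGDIR", ".")) / "vllm_bridge_config.json"
with open(cfg_path) as f:
    bridge_cfg = json.load(f)
print(bridge_cfg["model"], bridge_cfg["tp_degree"], bridge_cfg["param_mappings"])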


 async def run_server(
     args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
 ) -> None:
+    """
+    Run the vLLM API server.
+
+    This is the main entry point that starts the HTTP server and
+    serves requests until shutdown.
+    """
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
@@ -172,6 +570,16 @@
     app = await init_app(args, llm_engine)
     assert engine is not None

+    # Log bridge endpoints
+    logger.info("Bridge endpoints available:")
+    logger.info("  GET  /bridge/info            - Get bridge status")
+    logger.info("  POST /bridge/init            - Initialize weight bridge")
+    logger.info("  POST /bridge/notify_update   - Notify of weight update")
+    logger.info("  GET  /bridge/state_dict_info - Get model parameters")
+    logger.info("  GET  /lora/status            - Get LoRA adapter status")
+    logger.info("  POST /lora/load              - Load LoRA adapter")
+    logger.info("  POST /lora/unload            - Unload LoRA adapter")
+
     shutdown_task = await serve_http(
         app,
         sock=None,
@@ -190,10 +598,20 @@
     await shutdown_task


+# =============================================================================
+# CLI Entry Point
+# =============================================================================
+
+
 if __name__ == "__main__":
     parser = FlexibleArgumentParser()
+
+    # Server configuration
     parser.add_argument("--host", type=str, default=None)
     parser.add_argument("--port", type=parser.check_port, default=8000)
+    parser.add_argument("--log-level", type=str, default="debug")
+
+    # SSL configuration
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -217,7 +635,8 @@ if __name__ == "__main__":
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy",
     )
-    parser.add_argument("--log-level", type=str, default="debug")
+
+    # Add vLLM engine arguments
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()