mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add managed vllm server
This commit is contained in:
parent
578175a709
commit
e6ac3abdcb
9 changed files with 597 additions and 15 deletions
|
|
@ -171,6 +171,8 @@ pre-commit install
|
|||
|
||||
You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running VLLM or SGLang inference server as well as any other [configuration changes](CONFIG.md) you'd like to make, such as the group size, then:
|
||||
|
||||
> **Note:** By default, Atropos uses the OpenAI-compatible API endpoint which works with any provider. For enhanced features, use `VLLMServer` (atroposlib/envs/server_handling/vllm_server.py) or `SGLangServer` (atroposlib/envs/server_handling/sglang_server.py) for direct access to native APIs with full token and logprob tracking.
|
||||
|
||||
```bash
|
||||
# Start the API server
|
||||
run-api
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c
|
|||
* **`server_cls: Type[APIServer]`**:
|
||||
* Default: `APIServer`
|
||||
* Purpose: Specifies the class to be used for managing interactions with API servers (e.g., inference endpoints). Should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting with an existing API you can use this to easily slot in any modifications you need.
|
||||
* **Note:** In most cases, you should use the `server_type` field in your `APIServerConfig` instead of overriding this. Set `server_type` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to automatically use the appropriate server class with enhanced features like native API access and full token/logprob tracking.
|
||||
|
||||
## Provided Functionality
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
|
||||
|
||||
**Server Compatibility:** ManagedServer works with all Atropos server types - `OpenAIServer`, `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
|
||||
|
||||
### Why Use ManagedServer?
|
||||
|
||||
**Before ManagedServer** (manual approach):
|
||||
|
|
|
|||
|
|
@ -108,8 +108,8 @@ class ServerBaseline(BaseModel):
|
|||
rolling_buffer_length: int = Field(
|
||||
default=1000, description="Length of the rolling buffer to store metrics."
|
||||
)
|
||||
server_type: Literal["openai", "trl", "sglang"] = Field(
|
||||
default="openai", description="Type of server to use, openai or trl"
|
||||
server_type: Literal["openai", "trl", "sglang", "vllm"] = Field(
|
||||
default="openai", description="Type of server to use"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from atroposlib.envs.server_handling.server_baseline import (
|
|||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
from atroposlib.envs.server_handling.sglang_server import SGLangServer
|
||||
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
|
||||
from atroposlib.envs.server_handling.vllm_server import VLLMServer
|
||||
|
||||
|
||||
class ServerManagerConfig(BaseModel):
|
||||
|
|
@ -58,6 +59,8 @@ class ServerManager:
|
|||
server_class = TrlVllmServer
|
||||
elif configs.server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
elif configs.server_type == "vllm":
|
||||
server_class = VLLMServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs.server_type}")
|
||||
else:
|
||||
|
|
@ -67,6 +70,8 @@ class ServerManager:
|
|||
server_class = TrlVllmServer
|
||||
elif configs[0].server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
elif configs[0].server_type == "vllm":
|
||||
server_class = VLLMServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs[0].server_type}")
|
||||
if testing:
|
||||
|
|
@ -198,7 +203,7 @@ class ServerManager:
|
|||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
@ -231,7 +236,7 @@ class ServerManager:
|
|||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
@ -276,7 +281,7 @@ class ServerManager:
|
|||
finish_reasons.extend(out_finish_reasons)
|
||||
return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
|
||||
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
|
|||
344
atroposlib/envs/server_handling/vllm_server.py
Normal file
344
atroposlib/envs/server_handling/vllm_server.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
# This requires a customized vLLM api server
|
||||
# see example_trainer/vllm_api_server.py for an example
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
from pydantic_cli import FailedExecutionException
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
|
||||
|
||||
class VLLMServer(APIServer):
    """
    Client for a customized vLLM API server.

    Chat and plain completions are issued through the OpenAI-compatible
    client, while health checks and token/logprob completions are sent with
    aiohttp to the server's ``/health`` and ``/generate`` routes.
    """

    def __init__(self, config: APIServerConfig):
        """Create the OpenAI client and tokenizer, then initialise the base class."""
        self.openai = openai.AsyncClient(
            api_key=config.api_key,
            base_url=config.base_url,
            timeout=config.timeout,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        super().__init__(config)

    def _native_url(self, route: str) -> str:
        """
        Build the URL of a native (non-OpenAI) route such as ``health``.

        Only a trailing ``/v1`` segment is stripped from the configured base
        URL, so a ``/v1`` that appears elsewhere in the URL is left untouched
        and a trailing slash does not produce a double slash.
        """
        base_url = self.config.base_url.rstrip("/")
        if base_url.endswith("/v1"):
            base_url = base_url[: -len("/v1")]
        return f"{base_url}/{route}"

    def _auth_headers(self) -> dict:
        """Return the bearer-token header, or no headers when no API key is set."""
        if self.config.api_key:
            return {"Authorization": f"Bearer {self.config.api_key}"}
        return {}

    async def check_server_status_task(self, chat_completion: bool = True):
        """
        Poll the server's ``/health`` route once per second, forever.

        ``self.server_healthy`` is set to True after a successful response and
        to False after any failure. ``chat_completion`` is accepted for
        interface compatibility and is not used here.
        """
        while True:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(
                        self._native_url("health"),
                        headers=self._auth_headers(),
                        timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                    ) as response:
                        response.raise_for_status()
                self.server_healthy = True
            except Exception:
                # Any failure (connection error, timeout, bad status) marks the
                # server unhealthy; task cancellation is not an Exception and
                # still propagates.
                self.server_healthy = False
            await asyncio.sleep(1)

    async def _create_with_n_fallback(self, create, kwargs: dict):
        """
        Call ``create(**kwargs)`` and make sure ``n`` choices come back.

        ``create`` is one of the OpenAI client's ``create`` coroutines. When the
        backend is known to ignore ``n``, ``n`` single requests are issued
        concurrently and merged. Otherwise one request is sent; if it returns a
        single choice although ``n`` were requested, the backend is flagged as
        ignoring ``n`` and the missing ``n - 1`` requests are issued.

        Raises:
            ValueError: if the backend returns neither 1 nor ``n`` choices.
        """
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            # Merging covers every n; no further request is needed for n == 1.
            for extra in completion_list[1:]:
                completions.choices.extend(extra.choices)
            return completions

        n = kwargs.get("n", 1)
        completions = await create(**kwargs)
        if len(completions.choices) == n:
            return completions
        if len(completions.choices) != 1:
            raise ValueError(
                f"Expected 1 or {n} completions, got {len(completions.choices)}!"
            )
        warnings.warn("n kwarg is ignored by the API, setting to True")
        self.config.n_kwarg_is_ignored = True
        completion_list = await asyncio.gather(
            *[create(**kwargs) for _ in range(1, n)]
        )
        for extra in completion_list:
            completions.choices.extend(extra.choices)
        return completions

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the openai client.

        Requires ``model`` and ``messages`` keyword arguments.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for chat completion!"
        assert (
            kwargs.get("messages", None) is not None
        ), "Messages are required for chat completion!"
        return await self._create_with_n_fallback(
            self.openai.chat.completions.create, kwargs
        )

    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the openai client.

        Requires ``model`` and ``prompt`` keyword arguments.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
        ), "Prompt is required for completion!"
        return await self._create_with_n_fallback(
            self.openai.completions.create, kwargs
        )

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for tokens and logprobs completion using VLLM's native API.

        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs,
        finish_reasons). ``prompt_tokens`` is the list of prompt token ids that
        was sent; the other three hold one entry per returned completion.

        ``model`` is required but not forwarded to the server. Either
        ``input_ids`` or ``prompt`` must be given; ``input_ids`` wins when both
        are present. ``max_new_tokens`` is renamed to ``max_tokens``. All other
        keyword arguments are forwarded in the request body.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for completion!"

        # Use input_ids if provided (from ManagedServer), otherwise tokenize
        # the prompt. Both keys are removed so neither leaks into the request;
        # an explicit input_ids=None falls back to the prompt.
        input_ids = kwargs.pop("input_ids", None)
        prompt = kwargs.pop("prompt", None)
        if input_ids is not None:
            prompt_tokens = input_ids
        else:
            prompt_tokens = self.tokenizer.encode(prompt)

        # Check for double BOS token, can happen if you use chat templates and
        # forget that they insert a BOS token
        if (
            len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]
        if "max_new_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
        kwargs.pop("model", None)

        # Prepare request for VLLM native API; caller kwargs may override
        # the default "logprobs" value.
        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
        request_data.update(kwargs)

        # Make async request to VLLM /generate endpoint
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._native_url("generate"),
                json=request_data,
                headers=self._auth_headers(),
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
                response.raise_for_status()
                results = await response.json()

        output_tokens_list = []
        output_logprobs_list = []
        finish_reasons_list = []
        for output_token_logprobs, finish_reason in zip(
            results["logprobs"], results["finish_reasons"]
        ):
            output_ids = []
            logprobs = []
            for item in output_token_logprobs:
                # Each position arrives as [{token_id: logprob}]; the first
                # dict entry is taken and its JSON string key converted to int.
                token_id, logprob = list(item[0].items())[0]
                output_ids.append(int(token_id))
                logprobs.append(logprob)

            output_tokens_list.append(output_ids)
            output_logprobs_list.append(logprobs)
            finish_reasons_list.append(finish_reason)

        return (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons_list,
        )
|
||||
|
||||
|
||||
def resolve_openai_configs(
    default_server_configs,
    openai_config_dict,
    yaml_config,
    cli_passed_flags,
    logger,
):
    """
    Work out the final server configuration from defaults, YAML and CLI flags.

    Precedence, first match wins:
      1. a YAML list with two or more entries -> list of APIServerConfig,
      2. a ServerBaseline default -> returned unchanged,
      3. a default list with two or more entries -> returned unchanged,
      4. otherwise a single APIServerConfig built from ``openai_config_dict``,
         returned bare when the default was an APIServerConfig and wrapped in
         a one-element list in every other case.

    Raises:
        FailedExecutionException: when CLI overrides are combined with a
            multi-server setup, or when a config cannot be constructed.
    """
    from atroposlib.envs.server_handling.server_manager import ServerBaseline

    cli_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
    yaml_section = yaml_config.get(OPENAI_NAMESPACE, None)
    cli_override_keys = [
        key for key, _ in cli_passed_flags.items() if key.startswith(cli_prefix)
    ]

    yaml_is_multi = isinstance(yaml_section, list) and len(yaml_section) >= 2
    default_is_multi = (
        not yaml_is_multi
        and isinstance(default_server_configs, list)
        and len(default_server_configs) >= 2
    )

    # Per-server CLI overrides are ambiguous once several servers exist.
    if (yaml_is_multi or default_is_multi) and cli_override_keys:
        raise FailedExecutionException(
            message=f"CLI overrides for OpenAI settings (--{cli_prefix}*) are not supported "
            f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' "
            "or a default list with length >= 2).",
            exit_code=2,
        )

    if yaml_is_multi:
        logger.info(
            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
        )
        try:
            return [APIServerConfig(**entry) for entry in yaml_section]
        except Exception as e:
            raise FailedExecutionException(
                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
            ) from e

    if isinstance(default_server_configs, ServerBaseline):
        logger.info("Using ServerBaseline configuration.")
        return default_server_configs

    if default_is_multi:
        logger.info("Using default multi-server configuration (length >= 2).")
        return default_server_configs

    logger.info(
        "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
    )
    try:
        merged_config = APIServerConfig(**openai_config_dict)
    except Exception as e:
        raise FailedExecutionException(
            f"Error creating final OpenAI configuration from merged settings: {e}\n"
            f"Merged Dict: {openai_config_dict}"
        ) from e

    if isinstance(default_server_configs, APIServerConfig):
        return merged_config

    if not isinstance(default_server_configs, list):
        logger.warning(
            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
            "Proceeding with single OpenAI server configuration based on merged settings."
        )
    return [merged_config]
|
||||
|
||||
|
||||
if __name__ == "__main__":

    async def test_tokens_and_logprobs():
        """Manual smoke test against a locally running customized vLLM server."""
        # Configure the server - update these values for your setup
        config = APIServerConfig(
            api_key="",  # Add your API key if needed
            base_url="http://localhost:8000",  # Update to your VLLM server URL
            model_name="Qwen/Qwen2.5-7B",  # Update to your model name
            timeout=120,
        )
        server = VLLMServer(config)

        # Exercise the tokens_and_logprobs_completion method
        print("Testing tokens_and_logprobs_completion...")
        try:
            result = await server.tokens_and_logprobs_completion(
                prompt="The capital of France is",
                n=4,
                max_tokens=32,
                temperature=1.0,
                top_p=1.0,
                stop=["User:", "Human:", "Assistant:", "</answer>"],
            )
            prompt_tokens, output_tokens, output_logprobs, finish_reasons = result

            logprob_heads = [lp[:5] for lp in output_logprobs]
            output_lengths = [len(tokens) for tokens in output_tokens]
            decoded = [server.tokenizer.decode(tokens) for tokens in output_tokens]
            responses = "\n\n".join(decoded)

            print("\nResults:")
            print(f"Prompt tokens: {prompt_tokens}")
            print(f"Output tokens: {output_tokens}")
            print(f"Output logprobs (first 5): {logprob_heads}")
            print(f"Finish reasons: {finish_reasons}")
            print(f"\nNumber of completions: {len(output_tokens)}")
            print(f"Output length: {output_lengths}")
            print(f"Responses:\n-{responses}")
            print(f"Health: {server.server_healthy}")

        except Exception as e:
            print(f"Error: {e}")
            import traceback

            traceback.print_exc()

    # Run the test
    asyncio.run(test_tokens_and_logprobs())
|
||||
|
|
@ -155,7 +155,7 @@ class MathEnv(BaseEnv):
|
|||
server_configs = ServerBaseline(
|
||||
model_name="Qwen/Qwen2.5-7B",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
server_type="sglang",
|
||||
server_type="vllm",
|
||||
)
|
||||
|
||||
return env_config, server_configs
|
||||
|
|
@ -255,15 +255,15 @@ class MathEnv(BaseEnv):
|
|||
return
|
||||
|
||||
async def rollout_and_score_eval(self, question, answer, subset):
|
||||
|
||||
completion = await self.server.completion(
|
||||
prompt=question,
|
||||
n=1,
|
||||
max_tokens=32765,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
stop=stop_list,
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.completion(
|
||||
prompt=question,
|
||||
n=1,
|
||||
max_tokens=32765,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
stop=stop_list,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
|
||||
resp = completion.choices[0].text
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen
|
|||
|
||||
**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.
|
||||
|
||||
### Custom vLLM Server
|
||||
|
||||
The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Python:** Python 3.8 or higher is recommended.
|
||||
|
|
|
|||
224
example_trainer/vllm_api_server.py
Normal file
224
example_trainer/vllm_api_server.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
# Based on https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.v1.engine.async_llm import AsyncLLMEngine
|
||||
|
||||
try:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.system_utils import set_ulimit
|
||||
except ImportError:
|
||||
from vllm.utils import FlexibleArgumentParser, set_ulimit
|
||||
from vllm.outputs import RequestOutput # noqa: F401
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> Response:
    """Liveness probe: reply with an empty HTTP 200 without touching the engine."""
    ok_response = Response(status_code=200)
    return ok_response
|
||||
|
||||
|
||||
@app.get("/health_generate")
async def health_generate() -> Response:
    """
    Check the health of the inference server by sending a special request to generate one token.

    Returns HTTP 200 once the generation finishes, or HTTP 499 if the request
    is cancelled while the generation is still running.
    """
    assert engine is not None
    # Cap the probe at a single token, as the docstring promises, so the
    # health check stays cheap.
    sampling_params = SamplingParams(max_tokens=1)
    request_id = random_uuid()
    results_generator = engine.generate(
        {"prompt_token_ids": [0]}, sampling_params, request_id
    )
    try:
        # Only completion matters; the generated output itself is discarded.
        async for _ in results_generator:
            pass
    except asyncio.CancelledError:
        return Response(status_code=499)
    return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The JSON body of the request is expected to carry:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    payload = await request.json()
    response = await _generate(payload, raw_request=request)
    return response
|
||||
|
||||
|
||||
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """
    Run one generation request against the global engine.

    ``request_dict`` must contain ``prompt``; ``stream`` is optional and every
    remaining key is passed to ``SamplingParams``. The output kind is always
    forced to FINAL_ONLY.

    Non-streaming responses are a JSON object with ``text``, ``prompt`` and
    ``finish_reasons``; when logprobs were requested it also carries
    ``logprobs`` (per output, per position, a one-element list holding a
    ``{token_id: logprob}`` dict), ``prompt_token_ids`` and ``token_ids``.
    Streaming responses are newline-delimited JSON objects with ``text``.
    A cancelled non-streaming request yields HTTP 499.
    """
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    def prompt_text(request_output) -> str:
        # Token-id prompts carry no text, so fall back to decoding the ids.
        # Used by both the streaming and non-streaming paths.
        text = request_output.prompt or engine.tokenizer.decode(
            request_output.prompt_token_ids
        )
        assert text is not None
        return text

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            streamed_prompt = prompt_text(request_output)
            text_outputs = [
                streamed_prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output  # type: RequestOutput
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    prompt = prompt_text(final_output)
    text_outputs = [output.text for output in final_output.outputs]
    finish_reasons = [output.finish_reason for output in final_output.outputs]
    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
    if sampling_params.logprobs is not None:
        # Each position is wrapped in a one-element list; the VLLMServer
        # client reads it as item[0].
        output_logprobs = [
            [
                [{key: value.logprob for key, value in logprob.items()}]
                for logprob in x.logprobs
            ]
            for x in final_output.outputs
        ]
        prompt_token_ids = final_output.prompt_token_ids
        output_token_ids = [x.token_ids for x in final_output.outputs]
        ret["logprobs"] = output_logprobs
        ret["prompt_token_ids"] = prompt_token_ids
        ret["token_ids"] = output_token_ids
    return JSONResponse(ret)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
    """Apply ``args.root_path`` to the module-level FastAPI app and return it."""
    # The module-level app is only mutated, never rebound, so no `global`
    # declaration is needed.
    setattr(app, "root_path", args.root_path)
    return app
|
||||
|
||||
|
||||
async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """
    Configure the FastAPI app and bind the module-level ``engine``.

    ``llm_engine`` is used when supplied; otherwise a new engine is built from
    the CLI arguments. The engine is also stored on ``app.state.engine_client``.
    """
    global engine

    configured_app = build_app(args)

    # Engine args are parsed even when an engine is injected.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )

    configured_app.state.engine_client = engine
    return configured_app
|
||||
|
||||
|
||||
async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """
    Initialise the app and serve it until the HTTP server shuts down.

    Logs the vLLM version and the parsed arguments, calls ``set_ulimit()``,
    builds the app via ``init_app`` and hands it to ``serve_http`` with the
    host, port, log level and SSL settings from ``args``. Extra keyword
    arguments are forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()
    served_app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        served_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    # serve_http hands back an awaitable; awaiting it keeps this coroutine
    # alive until the server has shut down.
    await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    # Server-level options, registered in this order ahead of the engine
    # options added by AsyncEngineArgs.
    server_options = [
        ("--host", {"type": str, "default": None}),
        ("--port", {"type": parser.check_port, "default": 8000}),
        ("--ssl-keyfile", {"type": str, "default": None}),
        ("--ssl-certfile", {"type": str, "default": None}),
        (
            "--ssl-ca-certs",
            {"type": str, "default": None, "help": "The CA certificates file"},
        ),
        (
            "--enable-ssl-refresh",
            {
                "action": "store_true",
                "default": False,
                "help": "Refresh SSL Context when SSL certificate files change",
            },
        ),
        (
            "--ssl-cert-reqs",
            {
                "type": int,
                "default": int(ssl.CERT_NONE),
                "help": "Whether client certificate is required (see stdlib ssl module's)",
            },
        ),
        (
            "--root-path",
            {
                "type": str,
                "default": None,
                "help": "FastAPI root_path when app is behind a path based routing proxy",
            },
        ),
        ("--log-level", {"type": str, "default": "debug"}),
    ]
    for flag, options in server_options:
        parser.add_argument(flag, **options)

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))
|
||||
Loading…
Add table
Add a link
Reference in a new issue