From e6ac3abdcb8b6a6315c8beb1a7065fcbaa05921e Mon Sep 17 00:00:00 2001 From: Dakota Date: Fri, 7 Nov 2025 13:06:49 -0600 Subject: [PATCH] add managed vllm server --- README.md | 2 + atroposlib/envs/README.md | 1 + .../envs/server_handling/MANAGED_SERVER.md | 2 + .../envs/server_handling/server_baseline.py | 4 +- .../envs/server_handling/server_manager.py | 11 +- .../envs/server_handling/vllm_server.py | 344 ++++++++++++++++++ environments/math_server_zero.py | 20 +- example_trainer/README.md | 4 + example_trainer/vllm_api_server.py | 224 ++++++++++++ 9 files changed, 597 insertions(+), 15 deletions(-) create mode 100644 atroposlib/envs/server_handling/vllm_server.py create mode 100644 example_trainer/vllm_api_server.py diff --git a/README.md b/README.md index cc49f372..b3ab6cf3 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,8 @@ pre-commit install You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running VLLM or SGLang inference server as well as any other [configuration changes](CONFIG.md) you'd like to make, such as the group size, then: + > **Note:** By default, Atropos uses the OpenAI-compatible API endpoint which works with any provider. For enhanced features, use `VLLMServer` (atroposlib/envs/server_handling/vllm_server.py) or `SGLangServer` (atroposlib/envs/server_handling/sglang_server.py) for direct access to native APIs with full token and logprob tracking. + ```bash # Start the API server run-api diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md index 75068311..7b94d0aa 100644 --- a/atroposlib/envs/README.md +++ b/atroposlib/envs/README.md @@ -181,6 +181,7 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c * **`server_cls: Type[APIServer]`**: * Default: `APIServer` * Purpose: Specifies the class to be used for managing interactions with API servers (e.g., inference endpoints). Should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting with an existing API you can use this to easily slot in any modifications you need. + * **Note:** In most cases, you should use the `server_type` field in your `APIServerConfig` instead of overriding this. Set `server_type` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to automatically use the appropriate server class with enhanced features like native API access and full token/logprob tracking. ## Provided Functionality diff --git a/atroposlib/envs/server_handling/MANAGED_SERVER.md b/atroposlib/envs/server_handling/MANAGED_SERVER.md index 20f7531e..fd3eb571 100644 --- a/atroposlib/envs/server_handling/MANAGED_SERVER.md +++ b/atroposlib/envs/server_handling/MANAGED_SERVER.md @@ -4,6 +4,8 @@ `ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments. +**Server Compatibility:** ManagedServer works with all Atropos server types - `OpenAIServer`, `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection. + ### Why Use ManagedServer? 
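Switching backends never requires environment-code changes, only the `server_type` field described above. A minimal config sketch (class and field names from this repo; the model, URL, and key values are illustrative):

```python
from atroposlib.envs.server_handling.server_baseline import APIServerConfig

# ServerManager maps server_type to the matching server class
# ("vllm" -> VLLMServer, "sglang" -> SGLangServer, "trl" -> TrlVllmServer).
config = APIServerConfig(
    model_name="Qwen/Qwen2.5-7B",         # illustrative model
    base_url="http://localhost:8000/v1",  # illustrative local server
    api_key="",
    server_type="vllm",  # or "openai" (default), "sglang", "trl"
)
```

With that in place, the comparison below shows what ManagedServer saves you on top of any of these backends.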
**Before ManagedServer** (manual approach): diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index 96e94b88..18dbf35d 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -108,8 +108,8 @@ class ServerBaseline(BaseModel): rolling_buffer_length: int = Field( default=1000, description="Length of the rolling buffer to store metrics." ) - server_type: Literal["openai", "trl", "sglang"] = Field( - default="openai", description="Type of server to use, openai or trl" + server_type: Literal["openai", "trl", "sglang", "vllm"] = Field( + default="openai", description="Type of server to use" ) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index d2c89ab8..20cc79af 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -18,6 +18,7 @@ from atroposlib.envs.server_handling.server_baseline import ( from atroposlib.envs.server_handling.server_harness import ServerHarness from atroposlib.envs.server_handling.sglang_server import SGLangServer from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer +from atroposlib.envs.server_handling.vllm_server import VLLMServer class ServerManagerConfig(BaseModel): @@ -58,6 +59,8 @@ class ServerManager: server_class = TrlVllmServer elif configs.server_type == "sglang": server_class = SGLangServer + elif configs.server_type == "vllm": + server_class = VLLMServer else: raise ValueError(f"Invalid server type: {configs.server_type}") else: @@ -67,6 +70,8 @@ class ServerManager: server_class = TrlVllmServer elif configs[0].server_type == "sglang": server_class = SGLangServer + elif configs[0].server_type == "vllm": + server_class = VLLMServer else: raise ValueError(f"Invalid server type: {configs[0].server_type}") if testing: @@ -198,7 +203,7 @@ class ServerManager: for completion in completions[1:]: out.choices.extend(completion.choices) return out - is_train = kwargs.get("split", "train") == "train" + is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 most_available_server_num_slots = -1 await self.wait_for_sem(is_train) @@ -231,7 +236,7 @@ class ServerManager: for completion in completions[1:]: out.choices.extend(completion.choices) return out - is_train = kwargs.get("split", "train") == "train" + is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 most_available_server_num_slots = -1 await self.wait_for_sem(is_train) @@ -276,7 +281,7 @@ class ServerManager: finish_reasons.extend(out_finish_reasons) return (prompt_tokens, output_tokens, output_logprobs, finish_reasons) - is_train = kwargs.get("split", "train") == "train" + is_train = kwargs.pop("split", "train") == "train" most_available_server = 0 most_available_server_num_slots = -1 await self.wait_for_sem(is_train) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py new file mode 100644 index 00000000..8e043cf9 --- /dev/null +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -0,0 +1,344 @@ +# This requires a customized vLLM api server +# see example_trainer/vllm_api_server.py for an example + +import asyncio +import warnings + +import aiohttp +import openai +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.completion import Completion +from pydantic_cli import FailedExecutionException +from 
transformers import AutoTokenizer
+
+from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
+from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
+
+
+class VLLMServer(APIServer):
+    """
+    vLLM server handling.
+    """
+
+    def __init__(self, config: APIServerConfig):
+        self.openai = openai.AsyncClient(
+            api_key=config.api_key,
+            base_url=config.base_url,
+            timeout=config.timeout,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+        super().__init__(config)
+
+    async def check_server_status_task(self, chat_completion: bool = True):
+        while True:
+            try:
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(
+                        f"{self.config.base_url.replace('/v1', '')}/health",
+                        headers=(
+                            {"Authorization": f"Bearer {self.config.api_key}"}
+                            if self.config.api_key
+                            else {}
+                        ),
+                        timeout=aiohttp.ClientTimeout(total=self.config.timeout),
+                    ) as response:
+                        response.raise_for_status()
+                        self.server_healthy = True
+            except Exception:
+                # Any failure (connection error, HTTP error, timeout) marks the
+                # server unhealthy until the next successful probe.
+                self.server_healthy = False
+            await asyncio.sleep(1)
+
+    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
+        """
+        Wrapper for chat completions using the OpenAI client.
+        """
+        assert (
+            kwargs.get("model", None) is not None
+        ), "Model is required for chat completion!"
+        assert (
+            kwargs.get("messages", None) is not None
+        ), "Messages are required for chat completion!"
+        if self.config.n_kwarg_is_ignored:
+            # The API ignores n, so fan out n single-completion requests and
+            # merge their choices into one ChatCompletion.
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            for c in completion_list[1:]:
+                completions.choices.extend(c.choices)
+        else:
+            n = kwargs.get("n", 1)
+            completions = await self.openai.chat.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                # The API silently ignored n: remember that, then issue the
+                # remaining n - 1 requests individually.
+                warnings.warn(
+                    "n kwarg is ignored by the API; setting n_kwarg_is_ignored to True"
+                )
+                self.config.n_kwarg_is_ignored = True
+                completion_list = await asyncio.gather(
+                    *[
+                        self.openai.chat.completions.create(**kwargs)
+                        for _ in range(1, n)
+                    ]
+                )
+                for c in completion_list:
+                    completions.choices.extend(c.choices)
+        return completions
+
+    async def _completion_wrapper(self, **kwargs) -> Completion:
+        """
+        Wrapper for text completions using the OpenAI client.
+        """
+        assert (
+            kwargs.get("model", None) is not None
+        ), "Model is required for completion!"
+        assert (
+            kwargs.get("prompt", None) is not None
+        ), "Prompt is required for completion!"
+        if self.config.n_kwarg_is_ignored:
+            n = kwargs.pop("n", 1)
+            completion_list = await asyncio.gather(
+                *[self.openai.completions.create(**kwargs) for _ in range(n)]
+            )
+            completions = completion_list[0]
+            for c in completion_list[1:]:
+                completions.choices.extend(c.choices)
+        else:
+            n = kwargs.get("n", 1)
+            completions = await self.openai.completions.create(**kwargs)
+            if len(completions.choices) != n:
+                if len(completions.choices) != 1:
+                    raise ValueError(
+                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
+                    )
+                warnings.warn(
+                    "n kwarg is ignored by the API; setting n_kwarg_is_ignored to True"
+                )
+                self.config.n_kwarg_is_ignored = True
+                completion_list = await asyncio.gather(
+                    *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
+                )
+                for c in completion_list:
+                    completions.choices.extend(c.choices)
+        return completions
+
+    async def _tokens_and_logprobs_completion_wrapper(
+        self, **kwargs
+    ) -> tuple[list, list, list, list]:
+        """
+        Wrapper for tokens-and-logprobs completion using vLLM's native API.
+        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
+        Each of the output elements is a list of lists (one per completion in the batch).
+        """
+        assert (
+            kwargs.get("model", None) is not None
+        ), "Model is required for completion!"
+        assert (
+            kwargs.get("prompt", None) is not None
+            or kwargs.get("input_ids", None) is not None
+        ), "Prompt or input_ids is required for completion!"
+
+        # Use input_ids if provided (from ManagedServer); otherwise tokenize the prompt
+        if "input_ids" in kwargs:
+            prompt_tokens = kwargs.pop("input_ids")
+            kwargs.pop("prompt", None)  # Remove prompt if it exists
+        else:
+            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
+
+        # Check for a double BOS token, which can happen when a chat template
+        # inserts a BOS token on top of the one the tokenizer already added
+        if (
+            len(prompt_tokens) >= 2
+            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
+        ):
+            prompt_tokens = prompt_tokens[1:]
+        if "max_new_tokens" in kwargs:
+            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
+        # The native endpoint serves a single model, so drop the model kwarg
+        kwargs.pop("model", None)
+        # Prepare the request for vLLM's native API
+        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
+        request_data.update(kwargs)
+
+        # Make an async request to the vLLM /generate endpoint
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"{self.config.base_url.replace('/v1', '')}/generate",
+                json=request_data,
+                headers=(
+                    {"Authorization": f"Bearer {self.config.api_key}"}
+                    if self.config.api_key
+                    else {}
+                ),
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
+            ) as response:
+                response.raise_for_status()
+                results = await response.json()
+        output_tokens_list = []
+        output_logprobs_list = []
+        finish_reasons_list = []
+        for output_token_logprobs, finish_reason in zip(
+            results["logprobs"], results["finish_reasons"]
+        ):
+            # Each generated token arrives as a one-element list [{token_id: logprob}]
+            logprobs = [list(item[0].values())[0] for item in output_token_logprobs]
+            output_ids = [
+                int(list(item[0].keys())[0]) for item in output_token_logprobs
+            ]
+            output_tokens_list.append(output_ids)
+            output_logprobs_list.append(logprobs)
+            finish_reasons_list.append(finish_reason)
+
+        return (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons_list,
+        )
+
+
+def resolve_openai_configs(
+    default_server_configs,
+    openai_config_dict,
+    yaml_config,
+    cli_passed_flags,
+    logger,
+):
+    """
+    Helper to resolve the final server_configs, handling single, multiple servers, and overrides.
+ """ + from atroposlib.envs.server_handling.server_manager import ServerBaseline + + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None) + openai_cli_config = { + k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix) + } + + is_multi_server_yaml = ( + isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2 + ) + is_multi_server_default = ( + (not is_multi_server_yaml) + and isinstance(default_server_configs, list) + and len(default_server_configs) >= 2 + ) + + if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: + raise FailedExecutionException( + message=f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported " + f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' " + "or a default list with length >= 2).", + exit_code=2, + ) + + if is_multi_server_yaml: + logger.info( + f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." + ) + try: + server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config] + except Exception as e: + raise FailedExecutionException( + f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" + ) from e + elif isinstance(default_server_configs, ServerBaseline): + logger.info("Using ServerBaseline configuration.") + server_configs = default_server_configs + elif is_multi_server_default: + logger.info("Using default multi-server configuration (length >= 2).") + server_configs = default_server_configs + else: + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) + try: + final_openai_config = APIServerConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + + if isinstance(default_server_configs, APIServerConfig): + server_configs = final_openai_config + elif isinstance(default_server_configs, list): + server_configs = [final_openai_config] + else: + logger.warning( + f"Unexpected type for default_server_configs: {type(default_server_configs)}. " + f"Proceeding with single OpenAI server configuration based on merged settings." 
+ ) + server_configs = [final_openai_config] + + return server_configs + + +if __name__ == "__main__": + + async def test_tokens_and_logprobs(): + # Configure the server - update these values for your setup + config = APIServerConfig( + api_key="", # Add your API key if needed + base_url="http://localhost:8000", # Update to your VLLM server URL + model_name="Qwen/Qwen2.5-7B", # Update to your model name + timeout=120, + ) + + server = VLLMServer(config) + + # Test the tokens_and_logprobs_completion method + print("Testing tokens_and_logprobs_completion...") + try: + prompt_tokens, output_tokens, output_logprobs, finish_reasons = ( + await server.tokens_and_logprobs_completion( + prompt="The capital of France is", + n=4, + max_tokens=32, + temperature=1.0, + top_p=1.0, + stop=["User:", "Human:", "Assistant:", ""], + ) + ) + + print("\nResults:") + print(f"Prompt tokens: {prompt_tokens}") + print(f"Output tokens: {output_tokens}") + print(f"Output logprobs (first 5): {[lp[:5] for lp in output_logprobs]}") + print(f"Finish reasons: {finish_reasons}") + print(f"\nNumber of completions: {len(output_tokens)}") + print(f"Output length: {[len(tokens) for tokens in output_tokens]}") + responses = "\n\n".join( + [server.tokenizer.decode(tokens) for tokens in output_tokens] + ) + print(f"Responses:\n-{responses}") + print(f"Health: {server.server_healthy}") + + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() + + # Run the test + asyncio.run(test_tokens_and_logprobs()) diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index c3aa3ea0..1432ab4d 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -155,7 +155,7 @@ class MathEnv(BaseEnv): server_configs = ServerBaseline( model_name="Qwen/Qwen2.5-7B", num_requests_for_eval=256, # since evaling only on one... - server_type="sglang", + server_type="vllm", ) return env_config, server_configs @@ -255,15 +255,15 @@ class MathEnv(BaseEnv): return async def rollout_and_score_eval(self, question, answer, subset): - - completion = await self.server.completion( - prompt=question, - n=1, - max_tokens=32765, - temperature=0.0, - split="eval", - stop=stop_list, - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.completion( + prompt=question, + n=1, + max_tokens=32765, + temperature=0.0, + split="eval", + stop=stop_list, + ) loop = asyncio.get_event_loop() gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer resp = completion.choices[0].text diff --git a/example_trainer/README.md b/example_trainer/README.md index 0ccd7a3d..0b2f883c 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -8,6 +8,10 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen **Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training. +### Custom vLLM Server + +The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs. + ## Prerequisites 1. **Python:** Python 3.8 or higher is recommended. 
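To make the `Custom vLLM Server` contract concrete, here is a minimal client sketch against the `/generate` endpoint defined in `vllm_api_server.py` below (the URL and token ids are illustrative; the request shape mirrors what `VLLMServer._tokens_and_logprobs_completion_wrapper` sends):

```python
import asyncio

import aiohttp


async def demo() -> None:
    payload = {
        # vLLM accepts a TokensPrompt dict of pre-tokenized input ids
        "prompt": {"prompt_token_ids": [791, 6864, 315]},  # illustrative ids
        "logprobs": 0,  # request logprobs for the chosen tokens
        "n": 2,
        "max_tokens": 16,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8000/generate", json=payload
        ) as resp:
            resp.raise_for_status()
            results = await resp.json()
    # The response carries text, prompt, and finish_reasons, plus logprobs,
    # prompt_token_ids, and token_ids whenever logprobs is set.
    print(results["finish_reasons"], results["token_ids"])


asyncio.run(demo())
```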
diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py new file mode 100644 index 00000000..920c71db --- /dev/null +++ b/example_trainer/vllm_api_server.py @@ -0,0 +1,224 @@ +# Based on https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" + +import asyncio +import json +import ssl +from argparse import Namespace +from collections.abc import AsyncGenerator +from typing import Any + +import vllm.envs as envs +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import random_uuid +from vllm.v1.engine.async_llm import AsyncLLMEngine + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser + from vllm.utils.system_utils import set_ulimit +except ImportError: + from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.outputs import RequestOutput # noqa: F401 +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate() -> Response: + """ + Check the health of the inference server by sending a special request to generate one token. + """ + assert engine is not None + sampling_params = SamplingParams() + request_id = random_uuid() + results_generator = engine.generate( + {"prompt_token_ids": [0]}, sampling_params, request_id + ) + try: + async for request_output in results_generator: + final_output = request_output # type: RequestOutput # noqa: F841 + except asyncio.CancelledError: + return Response(status_code=499) + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). 
+ """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in request_output.outputs] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output # type: RequestOutput + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt or engine.tokenizer.decode( + final_output.prompt_token_ids + ) + assert prompt is not None + text_outputs = [output.text for output in final_output.outputs] + finish_reasons = [output.finish_reason for output in final_output.outputs] + ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons} + if sampling_params.logprobs is not None: + output_logprobs = [ + [ + [{key: value.logprob for key, value in logprob.items()}] + for logprob in x.logprobs + ] + for x in final_output.outputs + ] + prompt_token_ids = final_output.prompt_token_ids + output_token_ids = [x.token_ids for x in final_output.outputs] + ret["logprobs"] = output_logprobs + ret["prompt_token_ids"] = prompt_token_ids + ret["token_ids"] = output_token_ids + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app # noqa: F824 + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: AsyncLLMEngine | None = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = ( + llm_engine + if llm_engine is not None + else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER + ) + ) + app.state.engine_client = engine + return app + + +async def run_server( + args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any +) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + sock=None, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=parser.check_port, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, 
default=None) + parser.add_argument( + "--ssl-ca-certs", type=str, default=None, help="The CA certificates file" + ) + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change", + ) + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)", + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy", + ) + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + asyncio.run(run_server(args))
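For local testing, the server can be started with the usual `AsyncEngineArgs` flags, e.g. `python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-7B --port 8000` (model name illustrative). Since it exposes the `/health` endpoint that `VLLMServer.check_server_status_task` polls, a script that launches it before training can wait for readiness with a sketch like this (URL illustrative):

```python
import asyncio

import aiohttp


async def wait_until_healthy(base_url: str = "http://localhost:8000") -> None:
    """Poll /health until the API server answers 200."""
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                async with session.get(f"{base_url}/health") as resp:
                    if resp.status == 200:
                        return
            except aiohttp.ClientError:
                pass  # server not up yet; keep polling
            await asyncio.sleep(1)


asyncio.run(wait_until_healthy())
```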