# Based on https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any

import vllm.envs as envs
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.v1.engine.async_llm import AsyncLLMEngine

try:
    from vllm.utils.argparse_utils import FlexibleArgumentParser
    from vllm.utils.system_utils import set_ulimit
except ImportError:
    # Fall back to the flat `vllm.utils` location when the split modules are
    # not importable.
    from vllm.utils import FlexibleArgumentParser, set_ulimit

from vllm.outputs import RequestOutput  # noqa: F401
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")

app = FastAPI()
# Module-wide engine handle; populated by `init_app` before serving starts.
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check: always answers 200 without touching the engine."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate() -> Response:
    """Check the health of the inference server with a one-token generation.

    Sends a request whose prompt is the single token id 0 and drains the
    engine's output stream.

    Returns:
        200 once the generation completes, or 499 if the handler is
        cancelled while waiting on the engine.
    """
    assert engine is not None
    # Cap the probe at one token, as documented, instead of relying on the
    # SamplingParams default length.
    sampling_params = SamplingParams(max_tokens=1)
    request_id = random_uuid()
    results_generator = engine.generate(
        {"prompt_token_ids": [0]}, sampling_params, request_id
    )
    try:
        # Only completion matters here; the generated content is discarded.
        async for _ in results_generator:
            pass
    except asyncio.CancelledError:
        return Response(status_code=499)
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    return await _generate(request_dict, raw_request=request)


def _prompt_text(request_output: RequestOutput) -> str:
    """Return the prompt of *request_output* as text.

    Uses the prompt string carried by the output when there is one and
    otherwise decodes `prompt_token_ids` with the engine's tokenizer, so
    requests submitted as token IDs are handled as well.
    """
    assert engine is not None
    prompt = request_output.prompt or engine.tokenizer.decode(
        request_output.prompt_token_ids
    )
    assert prompt is not None
    return prompt


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request against the global engine.

    Args:
        request_dict: Parsed JSON body. `prompt` (required) and `stream`
            (optional, default False) are removed from it; every remaining
            key is forwarded to `SamplingParams`.
        raw_request: The originating request, passed through for the
            `with_cancellation` decorator.

    Returns:
        - 400 JSON error if `prompt` is missing.
        - A `StreamingResponse` of newline-delimited JSON objects
          `{"text": [...]}` (prompt + generated text) when `stream` is truthy.
        - Otherwise a `JSONResponse` with `text`, `prompt`, `finish_reasons`
          and, when logprobs were requested, `logprobs`, `prompt_token_ids`
          and `token_ids`.
        - 499 if the non-streaming wait is cancelled.
    """
    if "prompt" not in request_dict:
        # Answer with a client error instead of an unhandled KeyError (500).
        return JSONResponse(
            {"error": "Missing required field: 'prompt'"}, status_code=400
        )
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    # Streaming needs intermediate outputs; FINAL_ONLY would collapse the
    # stream into a single chunk. Non-streaming only needs the final result.
    request_dict["output_kind"] = (
        RequestOutputKind.CUMULATIVE if stream else RequestOutputKind.FINAL_ONLY
    )
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            # Same prompt fallback as the non-streaming path, so token-ID
            # prompts do not trip an assertion mid-stream.
            prompt_text = _prompt_text(request_output)
            text_outputs = [
                prompt_text + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output: RequestOutput | None = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    completions = final_output.outputs
    ret: dict[str, Any] = {
        "text": [output.text for output in completions],
        "prompt": _prompt_text(final_output),
        "finish_reasons": [output.finish_reason for output in completions],
    }
    if sampling_params.logprobs is not None:
        # Per completion, per generated position: a one-element list holding
        # the {key: logprob} mapping for that position.
        ret["logprobs"] = [
            [
                [{key: value.logprob for key, value in logprob.items()}]
                for logprob in x.logprobs
            ]
            for x in completions
        ]
        ret["prompt_token_ids"] = final_output.prompt_token_ids
        ret["token_ids"] = [x.token_ids for x in completions]
    return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
    """Apply `args.root_path` to the module-level app and return it."""
    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Configure the app and bind the global engine.

    Args:
        args: Parsed CLI arguments; used for the root path and, when no
            engine is supplied, to build the engine.
        llm_engine: Optional pre-built engine to serve instead of creating one.

    Returns:
        The configured FastAPI app, with the engine also stored on
        `app.state.engine_client`.
    """
    global engine

    configured_app = build_app(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        # Engine arguments are only parsed when an engine must be built.
        engine_args = AsyncEngineArgs.from_cli_args(args)
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    configured_app.state.engine_client = engine
    return configured_app


async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Initialise the app and serve it over HTTP until shutdown completes.

    Args:
        args: Parsed CLI arguments (host, port, SSL and logging options,
            plus whatever `init_app` needs).
        llm_engine: Optional pre-built engine forwarded to `init_app`.
        **uvicorn_kwargs: Extra keyword arguments forwarded to `serve_http`.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    server_app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        server_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    # `serve_http` hands back an awaitable that finishes on server shutdown.
    await shutdown_task


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    # Engine options are registered on top of the server options above.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))