"""A server that provides OpenAI-compatible RESTful APIs. It supports:
|
|
|
|
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
|
|
- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
|
|
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
|
|
|
|
Usage:
|
|
python3 -m fastchat.serve.openai_api_server
|
|
"""
|
|
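# A minimal client-side sketch (assumes the server is already running on the
# default host/port; the model name is only illustrative):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "vicuna-7b-v1.5",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
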
import asyncio
import argparse
import json
import os
from typing import AsyncGenerator, Optional, Union, Dict, List, Any

import aiohttp
import fastapi
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx

from pydantic_settings import BaseSettings
import shortuuid
import tiktoken
import uvicorn

from fastchat.constants import (
    WORKER_API_TIMEOUT,
    WORKER_API_EMBEDDING_BATCH_SIZE,
    ErrorCode,
)
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.protocol.openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DeltaMessage,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
    APITokenCheckRequest,
    APITokenCheckResponse,
    APITokenCheckResponseItem,
)
from fastchat.utils import build_logger

logger = build_logger("openai_api_server", "openai_api_server.log")

conv_template_map = {}

fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)

async def fetch_remote(url, pload=None, name=None):
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            chunks = []
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return json.dumps(ret)

            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
        output = b"".join(chunks)

    if name is not None:
        res = json.loads(output)
        if name != "":
            res = res[name]
        return res

    return output

class AppSettings(BaseSettings):
    # The address of the model controller.
    controller_address: str = "http://localhost:21001"
    api_keys: Optional[List[str]] = None

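# Note: as a pydantic BaseSettings subclass, these fields can also be seeded
# from environment variables (e.g. CONTROLLER_ADDRESS); the CLI flags parsed
# in create_openai_api_server() normally overwrite them at startup.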
app_settings = AppSettings()
app = fastapi.FastAPI()
headers = {"User-Agent": "FastChat API Server"}
get_bearer_token = HTTPBearer(auto_error=False)

async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    if app_settings.api_keys:
        if auth is None or (token := auth.credentials) not in app_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "Invalid API key.",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None

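# When --api-keys is set, clients authenticate with a standard Bearer header,
# e.g. (YOUR_API_KEY is a placeholder):
#   curl http://localhost:8000/v1/models -H "Authorization: Bearer YOUR_API_KEY"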
def create_error_response(code: int, message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, code=code).model_dump(), status_code=400
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

async def check_model(request) -> Optional[JSONResponse]:
    controller_address = app_settings.controller_address
    ret = None

    models = await fetch_remote(controller_address + "/list_models", None, "models")
    if request.model not in models:
        ret = create_error_response(
            ErrorCode.INVALID_MODEL,
            f"Only {'&&'.join(models)} allowed now; your model: {request.model}",
        )
    return ret

async def check_length(request, prompt, max_tokens, worker_addr):
    if (
        not isinstance(max_tokens, int) or max_tokens <= 0
    ):  # the model worker does not support max_tokens=None
        max_tokens = 1024 * 1024

    context_len = await fetch_remote(
        worker_addr + "/model_details", {"model": request.model}, "context_length"
    )
    token_num = await fetch_remote(
        worker_addr + "/count_token",
        {"model": request.model, "prompt": prompt},
        "count",
    )
    length = min(max_tokens, context_len - token_num)

    if length <= 0:
        return None, create_error_response(
            ErrorCode.CONTEXT_OVERFLOW,
            f"This model's maximum context length is {context_len} tokens. "
            f"However, your messages resulted in {token_num} tokens. "
            f"Please reduce the length of the messages.",
        )

    return length, None

def check_requests(request) -> Optional[JSONResponse]:
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of range. Either set top_k to -1 or >= 1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None

def process_input(model_name, inp):
    if isinstance(inp, str):
        inp = [inp]
    elif isinstance(inp, list):
        if isinstance(inp[0], (int, list)):
            try:
                decoding = tiktoken.model.encoding_for_model(model_name)
            except KeyError:
                logger.warning("Model not found. Using cl100k_base encoding.")
                decoding = tiktoken.get_encoding("cl100k_base")
            if isinstance(inp[0], int):
                # A single list of token ids.
                inp = [decoding.decode(inp)]
            else:
                # A batch of token-id lists.
                inp = [decoding.decode(text) for text in inp]

    return inp

def create_openai_logprobs(logprob_dict):
    """Create OpenAI-style logprobs."""
    return LogProbs(**logprob_dict) if logprob_dict is not None else None


def _add_to_set(s, new_stop):
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)

async def get_gen_params(
    model_name: str,
    worker_addr: str,
    messages: Union[str, List[Dict[str, str]]],
    *,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    presence_penalty: Optional[float],
    frequency_penalty: Optional[float],
    max_tokens: Optional[int],
    echo: Optional[bool],
    logprobs: Optional[int] = None,
    stop: Optional[Union[str, List[str]]],
    best_of: Optional[int] = None,
    use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
    conv = await get_conv(model_name, worker_addr)
    conv = Conversation(
        name=conv["name"],
        system_template=conv["system_template"],
        system_message=conv["system_message"],
        roles=conv["roles"],
        messages=list(conv["messages"]),  # prevent in-place modification
        offset=conv["offset"],
        sep_style=SeparatorStyle(conv["sep_style"]),
        sep=conv["sep"],
        sep2=conv["sep2"],
        stop_str=conv["stop_str"],
        stop_token_ids=conv["stop_token_ids"],
    )

    if isinstance(messages, str):
        prompt = messages
        images = []
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.set_system_message(message["content"])
            elif msg_role == "user":
                if isinstance(message["content"], list):
                    image_list = [
                        item["image_url"]["url"]
                        for item in message["content"]
                        if item["type"] == "image_url"
                    ]
                    text_list = [
                        item["text"]
                        for item in message["content"]
                        if item["type"] == "text"
                    ]

                    # TODO(chris): This only applies to the LLaVA model. Implement an image_token string in the conv template.
                    text = "<image>\n" * len(image_list)
                    text += "\n".join(text_list)
                    conv.append_message(conv.roles[0], (text, image_list))
                else:
                    conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        images = conv.get_images()

    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "logprobs": logprobs,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        "stop_token_ids": conv.stop_token_ids,
    }

    if len(images) > 0:
        gen_params["images"] = images

    if best_of is not None:
        gen_params.update({"best_of": best_of})
    if use_beam_search is not None:
        gen_params.update({"use_beam_search": use_beam_search})

    new_stop = set()
    _add_to_set(stop, new_stop)
    _add_to_set(conv.stop_str, new_stop)

    gen_params["stop"] = list(new_stop)

    logger.debug(f"==== request ====\n{gen_params}")
    return gen_params

async def get_worker_address(model_name: str) -> str:
    """
    Get the worker address for the requested model.

    :param model_name: The worker's model name
    :return: Worker address from the controller
    :raises: :class:`ValueError`: No available worker for the requested model
    """
    controller_address = app_settings.controller_address
    worker_addr = await fetch_remote(
        controller_address + "/get_worker_address", {"model": model_name}, "address"
    )

    # No available worker
    if worker_addr == "":
        raise ValueError(f"No available worker for {model_name}")
    logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
    return worker_addr

async def get_conv(model_name: str, worker_addr: str):
    conv_template = conv_template_map.get((worker_addr, model_name))
    if conv_template is None:
        conv_template = await fetch_remote(
            worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
        )
        conv_template_map[(worker_addr, model_name)] = conv_template
    return conv_template

@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
|
async def show_available_models():
|
|
controller_address = app_settings.controller_address
|
|
ret = await fetch_remote(controller_address + "/refresh_all_workers")
|
|
models = await fetch_remote(controller_address + "/list_models", None, "models")
|
|
|
|
models.sort()
|
|
# TODO: return real model permission details
|
|
model_cards = []
|
|
for m in models:
|
|
model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
|
|
return ModelList(data=model_cards)
|
|
|
|
|
|
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
|
|
async def create_chat_completion(request: ChatCompletionRequest):
|
|
"""Creates a completion for the chat message"""
|
|
error_check_ret = await check_model(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
error_check_ret = check_requests(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
worker_addr = await get_worker_address(request.model)
|
|
|
|
gen_params = await get_gen_params(
|
|
request.model,
|
|
worker_addr,
|
|
request.messages,
|
|
temperature=request.temperature,
|
|
top_p=request.top_p,
|
|
top_k=request.top_k,
|
|
presence_penalty=request.presence_penalty,
|
|
frequency_penalty=request.frequency_penalty,
|
|
max_tokens=request.max_tokens,
|
|
echo=False,
|
|
stop=request.stop,
|
|
)
|
|
|
|
max_new_tokens, error_check_ret = await check_length(
|
|
request,
|
|
gen_params["prompt"],
|
|
gen_params["max_new_tokens"],
|
|
worker_addr,
|
|
)
|
|
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
gen_params["max_new_tokens"] = max_new_tokens
|
|
|
|
if request.stream:
|
|
generator = chat_completion_stream_generator(
|
|
request.model, gen_params, request.n, worker_addr
|
|
)
|
|
return StreamingResponse(generator, media_type="text/event-stream")
|
|
|
|
choices = []
|
|
chat_completions = []
|
|
for i in range(request.n):
|
|
content = asyncio.create_task(generate_completion(gen_params, worker_addr))
|
|
chat_completions.append(content)
|
|
try:
|
|
all_tasks = await asyncio.gather(*chat_completions)
|
|
except Exception as e:
|
|
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
|
usage = UsageInfo()
|
|
for i, content in enumerate(all_tasks):
|
|
if isinstance(content, str):
|
|
content = json.loads(content)
|
|
|
|
if content["error_code"] != 0:
|
|
return create_error_response(content["error_code"], content["text"])
|
|
choices.append(
|
|
ChatCompletionResponseChoice(
|
|
index=i,
|
|
message=ChatMessage(role="assistant", content=content["text"]),
|
|
finish_reason=content.get("finish_reason", "stop"),
|
|
)
|
|
)
|
|
if "usage" in content:
|
|
task_usage = UsageInfo.model_validate(content["usage"])
|
|
for usage_key, usage_value in task_usage.model_dump().items():
|
|
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
|
|
|
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
|
|
|
|
|
async def chat_completion_stream_generator(
    model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
) -> AsyncGenerator[str, None]:
    """
    Event stream format:
    https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    """
    id = f"chatcmpl-{shortuuid.random()}"
    finish_stream_events = []
    for i in range(n):
        # First chunk with role
        choice_data = ChatCompletionResponseStreamChoice(
            index=i,
            delta=DeltaMessage(role="assistant"),
            finish_reason=None,
        )
        chunk = ChatCompletionStreamResponse(
            id=id, choices=[choice_data], model=model_name
        )
        yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"

        previous_text = ""
        async for content in generate_completion_stream(gen_params, worker_addr):
            if content["error_code"] != 0:
                yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return
            decoded_unicode = content["text"].replace("\ufffd", "")
            delta_text = decoded_unicode[len(previous_text) :]
            previous_text = (
                decoded_unicode
                if len(decoded_unicode) > len(previous_text)
                else previous_text
            )

            if len(delta_text) == 0:
                delta_text = None
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(content=delta_text),
                finish_reason=content.get("finish_reason", None),
            )
            chunk = ChatCompletionStreamResponse(
                id=id, choices=[choice_data], model=model_name
            )
            if delta_text is None:
                if content.get("finish_reason", None) is not None:
                    finish_stream_events.append(chunk)
                continue
            yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
    # The last delta message has no "content" field, so use exclude_none to drop it.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.model_dump_json(exclude_none=True)}\n\n"
    yield "data: [DONE]\n\n"

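# Illustrative SSE output of the generator above (ids and text are made up):
#   data: {"id": "chatcmpl-abc123", "choices": [{"index": 0, "delta": {"role": "assistant"}}], ...}
#   data: {"id": "chatcmpl-abc123", "choices": [{"index": 0, "delta": {"content": "Hi"}}], ...}
#   data: [DONE]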
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
|
async def create_completion(request: CompletionRequest):
|
|
error_check_ret = await check_model(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
error_check_ret = check_requests(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
request.prompt = process_input(request.model, request.prompt)
|
|
|
|
worker_addr = await get_worker_address(request.model)
|
|
for text in request.prompt:
|
|
max_tokens, error_check_ret = await check_length(
|
|
request, text, request.max_tokens, worker_addr
|
|
)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
if isinstance(max_tokens, int) and max_tokens < request.max_tokens:
|
|
request.max_tokens = max_tokens
|
|
|
|
if request.stream:
|
|
generator = generate_completion_stream_generator(
|
|
request, request.n, worker_addr
|
|
)
|
|
return StreamingResponse(generator, media_type="text/event-stream")
|
|
else:
|
|
text_completions = []
|
|
for text in request.prompt:
|
|
gen_params = await get_gen_params(
|
|
request.model,
|
|
worker_addr,
|
|
text,
|
|
temperature=request.temperature,
|
|
top_p=request.top_p,
|
|
top_k=request.top_k,
|
|
frequency_penalty=request.frequency_penalty,
|
|
presence_penalty=request.presence_penalty,
|
|
max_tokens=request.max_tokens,
|
|
logprobs=request.logprobs,
|
|
echo=request.echo,
|
|
stop=request.stop,
|
|
best_of=request.best_of,
|
|
use_beam_search=request.use_beam_search,
|
|
)
|
|
for i in range(request.n):
|
|
content = asyncio.create_task(
|
|
generate_completion(gen_params, worker_addr)
|
|
)
|
|
text_completions.append(content)
|
|
|
|
try:
|
|
all_tasks = await asyncio.gather(*text_completions)
|
|
except Exception as e:
|
|
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
|
|
|
choices = []
|
|
usage = UsageInfo()
|
|
for i, content in enumerate(all_tasks):
|
|
if content["error_code"] != 0:
|
|
return create_error_response(content["error_code"], content["text"])
|
|
choices.append(
|
|
CompletionResponseChoice(
|
|
index=i,
|
|
text=content["text"],
|
|
logprobs=create_openai_logprobs(content.get("logprobs", None)),
|
|
finish_reason=content.get("finish_reason", "stop"),
|
|
)
|
|
)
|
|
task_usage = UsageInfo.model_validate(content["usage"])
|
|
for usage_key, usage_value in task_usage.model_dump().items():
|
|
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
|
|
|
return CompletionResponse(
|
|
model=request.model, choices=choices, usage=UsageInfo.model_validate(usage)
|
|
)
|
|
|
|
|
|
async def generate_completion_stream_generator(
    request: CompletionRequest, n: int, worker_addr: str
):
    model_name = request.model
    id = f"cmpl-{shortuuid.random()}"
    finish_stream_events = []
    for text in request.prompt:
        for i in range(n):
            previous_text = ""
            gen_params = await get_gen_params(
                request.model,
                worker_addr,
                text,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                presence_penalty=request.presence_penalty,
                frequency_penalty=request.frequency_penalty,
                max_tokens=request.max_tokens,
                logprobs=request.logprobs,
                echo=request.echo,
                stop=request.stop,
            )
            async for content in generate_completion_stream(gen_params, worker_addr):
                if content["error_code"] != 0:
                    yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = content["text"].replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )
                # TODO: the choice index is not apparent from the worker stream.
                choice_data = CompletionResponseStreamChoice(
                    index=i,
                    text=delta_text,
                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
                    finish_reason=content.get("finish_reason", None),
                )
                chunk = CompletionStreamResponse(
                    id=id,
                    object="text_completion",
                    choices=[choice_data],
                    model=model_name,
                )
                if len(delta_text) == 0:
                    if content.get("finish_reason", None) is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
    # Finish chunks carry only the finish_reason; emit them after all content chunks.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.model_dump_json(exclude_unset=True)}\n\n"
    yield "data: [DONE]\n\n"

async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
    async with httpx.AsyncClient() as client:
        delimiter = b"\0"
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            buffer = b""
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                while (chunk_end := buffer.find(delimiter)) >= 0:
                    chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :]
                    if not chunk:
                        continue
                    yield json.loads(chunk.decode())

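# The worker stream is a sequence of JSON objects separated by NUL bytes, e.g.
# (illustrative): b'{"text": "Hel", "error_code": 0}\x00{"text": "Hello", "error_code": 0}\x00'
# Each object carries the cumulative text so far; httpx may split or merge
# network chunks arbitrarily, which is why the loop above buffers until it
# sees a complete delimiter-terminated message.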
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    return await fetch_remote(worker_addr + "/worker_generate", payload, "")

@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
|
|
@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
|
|
async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
|
|
"""Creates embeddings for the text"""
|
|
if request.model is None:
|
|
request.model = model_name
|
|
error_check_ret = await check_model(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
request.input = process_input(request.model, request.input)
|
|
|
|
data = []
|
|
token_num = 0
|
|
batch_size = WORKER_API_EMBEDDING_BATCH_SIZE
|
|
batches = [
|
|
request.input[i : min(i + batch_size, len(request.input))]
|
|
for i in range(0, len(request.input), batch_size)
|
|
]
|
|
for num_batch, batch in enumerate(batches):
|
|
payload = {
|
|
"model": request.model,
|
|
"input": batch,
|
|
"encoding_format": request.encoding_format,
|
|
}
|
|
embedding = await get_embedding(payload)
|
|
if "error_code" in embedding and embedding["error_code"] != 0:
|
|
return create_error_response(embedding["error_code"], embedding["text"])
|
|
data += [
|
|
{
|
|
"object": "embedding",
|
|
"embedding": emb,
|
|
"index": num_batch * batch_size + i,
|
|
}
|
|
for i, emb in enumerate(embedding["embedding"])
|
|
]
|
|
token_num += embedding["token_num"]
|
|
return EmbeddingsResponse(
|
|
data=data,
|
|
model=request.model,
|
|
usage=UsageInfo(
|
|
prompt_tokens=token_num,
|
|
total_tokens=token_num,
|
|
completion_tokens=None,
|
|
),
|
|
).model_dump(exclude_none=True)
|
|
|
|
|
|
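# A minimal request sketch for the endpoint above (model name is illustrative):
#   POST /v1/embeddings
#   {"model": "bge-large-en-v1.5", "input": ["The food was delicious."]}
# The response follows the OpenAI embeddings schema, with one "data" entry
# per input string.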
async def get_embedding(payload: Dict[str, Any]):
    model_name = payload["model"]
    worker_addr = await get_worker_address(model_name)

    embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
    return json.loads(embedding)


### GENERAL API - NOT OPENAI COMPATIBLE ###

@app.post("/api/v1/token_check")
|
|
async def count_tokens(request: APITokenCheckRequest):
|
|
"""
|
|
Checks the token count for each message in your list
|
|
This is not part of the OpenAI API spec.
|
|
"""
|
|
checkedList = []
|
|
for item in request.prompts:
|
|
worker_addr = await get_worker_address(item.model)
|
|
|
|
context_len = await fetch_remote(
|
|
worker_addr + "/model_details",
|
|
{"prompt": item.prompt, "model": item.model},
|
|
"context_length",
|
|
)
|
|
|
|
token_num = await fetch_remote(
|
|
worker_addr + "/count_token",
|
|
{"prompt": item.prompt, "model": item.model},
|
|
"count",
|
|
)
|
|
|
|
can_fit = True
|
|
if token_num + item.max_tokens > context_len:
|
|
can_fit = False
|
|
|
|
checkedList.append(
|
|
APITokenCheckResponseItem(
|
|
fits=can_fit, contextLength=context_len, tokenCount=token_num
|
|
)
|
|
)
|
|
|
|
return APITokenCheckResponse(prompts=checkedList)
|
|
|
|
|
|
@app.post("/api/v1/chat/completions")
|
|
async def create_chat_completion(request: APIChatCompletionRequest):
|
|
"""Creates a completion for the chat message"""
|
|
error_check_ret = await check_model(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
error_check_ret = check_requests(request)
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
worker_addr = await get_worker_address(request.model)
|
|
|
|
gen_params = await get_gen_params(
|
|
request.model,
|
|
worker_addr,
|
|
request.messages,
|
|
temperature=request.temperature,
|
|
top_p=request.top_p,
|
|
top_k=request.top_k,
|
|
presence_penalty=request.presence_penalty,
|
|
frequency_penalty=request.frequency_penalty,
|
|
max_tokens=request.max_tokens,
|
|
echo=False,
|
|
stop=request.stop,
|
|
)
|
|
|
|
if request.repetition_penalty is not None:
|
|
gen_params["repetition_penalty"] = request.repetition_penalty
|
|
|
|
max_new_tokens, error_check_ret = await check_length(
|
|
request,
|
|
gen_params["prompt"],
|
|
gen_params["max_new_tokens"],
|
|
worker_addr,
|
|
)
|
|
|
|
if error_check_ret is not None:
|
|
return error_check_ret
|
|
|
|
gen_params["max_new_tokens"] = max_new_tokens
|
|
|
|
if request.stream:
|
|
generator = chat_completion_stream_generator(
|
|
request.model, gen_params, request.n, worker_addr
|
|
)
|
|
return StreamingResponse(generator, media_type="text/event-stream")
|
|
|
|
choices = []
|
|
chat_completions = []
|
|
for i in range(request.n):
|
|
content = asyncio.create_task(generate_completion(gen_params, worker_addr))
|
|
chat_completions.append(content)
|
|
try:
|
|
all_tasks = await asyncio.gather(*chat_completions)
|
|
except Exception as e:
|
|
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
|
usage = UsageInfo()
|
|
for i, content in enumerate(all_tasks):
|
|
if content["error_code"] != 0:
|
|
return create_error_response(content["error_code"], content["text"])
|
|
choices.append(
|
|
ChatCompletionResponseChoice(
|
|
index=i,
|
|
message=ChatMessage(role="assistant", content=content["text"]),
|
|
finish_reason=content.get("finish_reason", "stop"),
|
|
)
|
|
)
|
|
task_usage = UsageInfo.model_validate(content["usage"])
|
|
for usage_key, usage_value in task_usage.model_dump().items():
|
|
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
|
|
|
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
|
|
|
|
|
### END GENERAL API - NOT OPENAI COMPATIBLE ###
|
|
|
|
|
|
def create_openai_api_server():
    parser = argparse.ArgumentParser(
        description="FastChat ChatGPT-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument(
        "--api-keys",
        type=lambda s: s.split(","),
        help="Optional list of comma-separated API keys",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    app_settings.controller_address = args.controller_address
    app_settings.api_keys = args.api_keys

    logger.info(f"args: {args}")
    return args


if __name__ == "__main__":
    args = create_openai_api_server()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
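# Example launches (file paths and addresses are placeholders):
#   python3 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8000
#   SSL_KEYFILE=key.pem SSL_CERTFILE=cert.pem \
#       python3 -m fastchat.serve.openai_api_server --ssl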