"""
|
|
A model worker that calls huggingface inference endpoint.
|
|
|
|
Register models in a JSON file with the following format:
|
|
{
|
|
"falcon-180b-chat": {
|
|
"model_name": "falcon-180B-chat",
|
|
"api_base": "https://api-inference.huggingface.co/models",
|
|
"model_path": "tiiuae/falcon-180B-chat",
|
|
"token": "hf_XXX",
|
|
"context_length": 2048
|
|
},
|
|
"zephyr-7b-beta": {
|
|
"model_name": "zephyr-7b-beta",
|
|
"model_path": "",
|
|
"api_base": "xxx",
|
|
"token": "hf_XXX",
|
|
"context_length": 4096
|
|
}
|
|
}
|
|
|
|
"model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
|
|
"""
|
|
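
# Illustrative launch command (the script name is whatever this file is saved
# as; only --model-info-file is required, see create_huggingface_api_worker()
# below, and a FastChat controller must already be listening at
# --controller-address):
#   python huggingface_api_worker.py --model-info-file model_info.json
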
import argparse
import asyncio
import json
import os
import uuid
from typing import List, Optional

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import InferenceClient

from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.utils import build_logger

worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

workers = []
worker_map = {}
app = FastAPI()


# Adapted from
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
def get_gen_kwargs(
    params,
    seed: Optional[int] = None,
):
    stop = params.get("stop", None)
    if isinstance(stop, list):
        stop_sequences = stop
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = []
    gen_kwargs = {
        "do_sample": True,
        "return_full_text": bool(params.get("echo", False)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "top_p": float(params.get("top_p", 1.0)),
        "temperature": float(params.get("temperature", 1.0)),
        "stop_sequences": stop_sequences,
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_k": params.get("top_k", None),
        "seed": seed,
    }
    # The text-generation backend expects 0 < top_p < 1, so nudge 1.0 just
    # below the boundary and drop the parameter entirely when it is 0.
    if gen_kwargs["top_p"] == 1:
        gen_kwargs["top_p"] = 0.9999999
    if gen_kwargs["top_p"] == 0:
        gen_kwargs.pop("top_p")
    # temperature == 0 is the conventional way to request greedy decoding.
    if gen_kwargs["temperature"] == 0:
        gen_kwargs.pop("temperature")
        gen_kwargs["do_sample"] = False
    return gen_kwargs
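

# Illustrative example: sending temperature=0 selects greedy decoding and the
# default top_p of 1.0 is nudged just below 1:
#   get_gen_kwargs({"temperature": 0, "stop": "###"})
#   -> {"do_sample": False, "return_full_text": False, "max_new_tokens": 256,
#       "top_p": 0.9999999, "stop_sequences": ["###"],
#       "repetition_penalty": 1.0, "top_k": None, "seed": None}
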

def could_be_stop(text, stop):
    """Check whether `text` ends with a prefix of any stop sequence."""
    for s in stop:
        if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
            return True
    return False
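

# Illustrative example: could_be_stop("Hello Us", ["User:"]) returns True
# because "Us" is a prefix of "User:"; the streaming loop below then holds
# the text back until the potential stop sequence resolves either way.
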

class HuggingfaceApiWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting to Hugging Face API {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

        if not no_register:
            self.init_heart_beat()

    def count_token(self, params):
        # This worker has no local tokenizer, so report a count of 0.
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret

    def generate_stream_gate(self, params):
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        stop = gen_kwargs["stop_sequences"]
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
            stop = list(set(stop))
            gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            if self.model_path == "":
                url = f"{self.api_base}"
            else:
                url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                s = next((x for x in stop if text.endswith(x)), None)
                if s is not None:
                    text = text[: -len(s)]
                    reason = "stop"
                    break
                # Hold back text that might be the start of a stop sequence.
                if could_be_stop(text, stop):
                    continue
                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                if reason not in ["stop", "length"]:
                    reason = None
                ret = {
                    "text": text,
                    "error_code": 0,
                    "finish_reason": reason,
                }
                yield json.dumps(ret).encode() + b"\0"

            # Yield a final message so the truncated text and finish reason
            # reach the client even when the loop breaks on a stop sequence.
            ret = {
                "text": text,
                "error_code": 0,
                "finish_reason": reason,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"
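
    # Each message yielded above is a JSON object terminated by a NUL byte,
    # e.g. b'{"text": "Hi", "error_code": 0, "finish_reason": null}\x00'.
    # Clients (and generate_gate below) split the stream on b"\0" and parse
    # each piece as JSON.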

    def generate_gate(self, params):
        # Drain the stream and return only the final message. The generator
        # always yields at least once (the except branch yields on failure),
        # so `x` is guaranteed to be bound here.
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())

    def get_embeddings(self, params):
        raise NotImplementedError()


def release_worker_semaphore(worker):
    worker.semaphore.release()


def acquire_worker_semaphore(worker):
    # Create the semaphore lazily so it is bound to the running event loop
    # rather than whichever loop (if any) exists at import time.
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(worker):
    background_tasks = BackgroundTasks()
    background_tasks.add_task(lambda: release_worker_semaphore(worker))
    return background_tasks
@app.post("/worker_generate_stream")
|
|
async def api_generate_stream(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
await acquire_worker_semaphore(worker)
|
|
generator = worker.generate_stream_gate(params)
|
|
background_tasks = create_background_tasks(worker)
|
|
return StreamingResponse(generator, background=background_tasks)
|
|
|
|
|
|
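

# Illustrative request against the streaming endpoint (default port 21002;
# the "model" value must match one of the registered model names):
#   curl -N -X POST http://localhost:21002/worker_generate_stream \
#        -H "Content-Type: application/json" \
#        -d '{"model": "zephyr-7b-beta", "prompt": "Hello", "temperature": 0.7}'
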
@app.post("/worker_generate")
|
|
async def api_generate(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
await acquire_worker_semaphore(worker)
|
|
output = worker.generate_gate(params)
|
|
release_worker_semaphore(worker)
|
|
return JSONResponse(output)
|
|
|
|
|
|
@app.post("/worker_get_embeddings")
|
|
async def api_get_embeddings(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
await acquire_worker_semaphore(worker)
|
|
embedding = worker.get_embeddings(params)
|
|
release_worker_semaphore(worker)
|
|
return JSONResponse(content=embedding)
|
|
|
|
|
|
@app.post("/worker_get_status")
|
|
async def api_get_status(request: Request):
|
|
return {
|
|
"model_names": [m for w in workers for m in w.model_names],
|
|
"speed": 1,
|
|
"queue_length": sum([w.get_queue_length() for w in workers]),
|
|
}
|
|
|
|
|
|
@app.post("/count_token")
|
|
async def api_count_token(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
return worker.count_token(params)
|
|
|
|
|
|
@app.post("/worker_get_conv_template")
|
|
async def api_get_conv(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
return worker.get_conv_template()
|
|
|
|
|
|
@app.post("/model_details")
|
|
async def api_model_details(request: Request):
|
|
params = await request.json()
|
|
worker = worker_map[params["model"]]
|
|
return {"context_length": worker.context_len}
|
|
|
|
|
|
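

# Illustrative request for a non-streaming endpoint; with the zephyr example
# from the module docstring, /model_details returns its context length:
#   curl -X POST http://localhost:21002/model_details \
#        -H "Content-Type: application/json" -d '{"model": "zephyr-7b-beta"}'
#   -> {"context_length": 4096}
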

def create_huggingface_api_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # All model-related parameters are listed in --model-info-file.
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )

    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []
    model_names_list = []
    conv_template_list = []

    for m in model_info:
        model_path_list.append(model_info[m]["model_path"])
        api_base_list.append(model_info[m]["api_base"])
        token_list.append(model_info[m]["token"])

        context_length = model_info[m]["context_length"]
        model_names = model_info[m].get("model_names", [m.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        conv_template = model_info[m].get("conv_template", None)

        context_length_list.append(context_length)
        model_names_list.append(model_names)
        conv_template_list.append(conv_template)

    logger.info(f"Model paths: {model_path_list}")
    logger.info(f"API bases: {api_base_list}")
    # Note: this writes the raw API tokens to the log file.
    logger.info(f"Tokens: {token_list}")
    logger.info(f"Context lengths: {context_length_list}")
    logger.info(f"Model names: {model_names_list}")
    logger.info(f"Conv templates: {conv_template_list}")

    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        model_names_list,
        conv_template_list,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        m = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(m)
        for name in model_names:
            worker_map[name] = m

    # Register all the models under a single worker address with the controller.
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200

    return args, workers


if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")