mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
389 lines
12 KiB
Python
389 lines
12 KiB
Python
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
|
|
import argparse
|
|
import asyncio
|
|
import dataclasses
|
|
from enum import Enum, auto
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import List, Union
|
|
import threading
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import StreamingResponse
|
|
import numpy as np
|
|
import requests
|
|
import uvicorn
|
|
|
|
from fastchat.constants import (
|
|
CONTROLLER_HEART_BEAT_EXPIRATION,
|
|
WORKER_API_TIMEOUT,
|
|
ErrorCode,
|
|
SERVER_ERROR_MSG,
|
|
)
|
|
from fastchat.utils import build_logger
|
|
|
|
|
|
# Module-wide logger obtained from fastchat's build_logger, called with the
# name "controller" and "controller.log" (presumably the log file name --
# confirm against fastchat.utils.build_logger).
logger = build_logger("controller", "controller.log")
|
|
|
|
|
|
class DispatchMethod(Enum):
    """Strategies the controller can use to choose a worker for a request."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Translate a command-line style name into a ``DispatchMethod``.

        Args:
            name: Either ``"lottery"`` or ``"shortest_queue"``.

        Returns:
            The matching enum member.

        Raises:
            ValueError: If ``name`` is not one of the two known names. The
                rejected value is included in the message.
        """
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        # Report the offending value so a bad configuration can be diagnosed.
        raise ValueError(f"Invalid dispatch method: {name}")
|
|
|
|
|
|
@dataclasses.dataclass
class WorkerInfo:
    """Bookkeeping record the controller stores for one registered worker."""

    # Names of the models this worker reported in its status.
    model_names: List[str]
    # Relative speed reported by the worker; used as a dispatch weight.
    speed: int
    # Queue length last reported by (or estimated for) the worker.
    queue_length: int
    # Whether the worker is subject to heart-beat expiration.
    check_heart_beat: bool
    # ``time.time()`` timestamp of the most recent registration or heart beat.
    last_heart_beat: float
    # Whether the worker was registered as multimodal.
    multimodal: bool
|
|
|
|
|
|
def heart_beat_controller(controller):
    """Loop forever, periodically asking ``controller`` to drop stale workers.

    Each iteration sleeps for ``CONTROLLER_HEART_BEAT_EXPIRATION`` seconds and
    then calls ``controller.remove_stale_workers_by_expiration()``. The
    function never returns.
    """
    while True:
        sweep_interval = CONTROLLER_HEART_BEAT_EXPIRATION
        time.sleep(sweep_interval)
        controller.remove_stale_workers_by_expiration()
|
|
|
|
|
|
class Controller:
    """Track registered workers and choose one to serve each model request.

    Workers are stored in ``worker_info``, a dict mapping the worker's
    address (used as the URL prefix for requests to it) to a ``WorkerInfo``.
    ``__init__`` starts a thread running ``heart_beat_controller``, which
    periodically calls ``remove_stale_workers_by_expiration``.
    """

    def __init__(self, dispatch_method: str):
        """Create an empty registry and start the heart-beat sweeper thread.

        Args:
            dispatch_method: Name accepted by ``DispatchMethod.from_str``.
        """
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,)
        )
        self.heart_beat_thread.start()

    def register_worker(
        self,
        worker_name: str,
        check_heart_beat: bool,
        worker_status: dict,
        multimodal: bool,
    ):
        """Add a worker to the registry, or overwrite its existing entry.

        Args:
            worker_name: Address of the worker.
            check_heart_beat: Whether the worker may be expired for missing
                heart beats.
            worker_status: Dict with ``model_names``, ``speed`` and
                ``queue_length``. When falsy, the status is fetched from the
                worker itself.
            multimodal: Whether the worker serves multimodal models.

        Returns:
            True when the worker was stored, False when no status could be
            obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"],
            worker_status["speed"],
            worker_status["queue_length"],
            check_heart_beat,
            time.time(),
            multimodal,
        )

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to the worker's ``/worker_get_status`` endpoint.

        Returns:
            The decoded JSON body, or None when the request raised a
            ``requests`` exception or the response status was not 200.
        """
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Forget a worker; a worker that is already gone is ignored."""
        # pop() with a default keeps removal idempotent: the heart-beat thread
        # and request handlers may both try to drop the same worker.
        self.worker_info.pop(worker_name, None)

    def refresh_all_workers(self):
        """Re-register every known worker using a freshly fetched status.

        The registry is emptied first; workers whose status cannot be
        fetched are therefore dropped.
        """
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(
                w_name, w_info.check_heart_beat, None, w_info.multimodal
            ):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by all workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def list_multimodal_models(self):
        """Return the distinct model names served by multimodal workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            if w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def list_language_models(self):
        """Return the distinct model names served by non-multimodal workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            if not w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Choose a worker that serves ``model_name``.

        With ``LOTTERY`` a worker is drawn at random with probability
        proportional to its speed. With ``SHORTEST_QUEUE`` the worker with the
        smallest ``queue_length / speed`` is chosen and its stored queue
        length is incremented.

        Returns:
            The chosen worker's address, or ``""`` when no worker is eligible.

        Raises:
            ValueError: If ``self.dispatch_method`` is not a known method.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            # The address is returned directly, without probing the worker.
            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name not in w_info.model_names:
                    continue
                if w_info.speed <= 0:
                    # A worker with no positive speed cannot be weighted
                    # (division by zero); the lottery branch likewise never
                    # selects a zero-speed worker.
                    continue
                worker_names.append(w_name)
                worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request being dispatched until the next heart beat
            # overwrites queue_length with the worker's own value.
            self.worker_info[w_name].queue_length += 1
            logger.info(
                f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
            )
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat and the queue length reported with it.

        Returns:
            False when the worker is not registered, True otherwise.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stale_workers_by_expiration(self):
        """Drop heart-beat-checked workers whose last beat is too old."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Iterate over a snapshot: this runs on the heart-beat thread while
        # other code may add or remove workers in the live dict.
        to_delete = [
            worker_name
            for worker_name, w_info in list(self.worker_info.items())
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def handle_no_worker(self, params):
        """Build the null-terminated JSON error sent when no worker exists."""
        logger.info(f"no worker: {params['model']}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_NO_WORKER,
        }
        return json.dumps(ret).encode() + b"\0"

    def handle_worker_timeout(self, worker_address):
        """Build the null-terminated JSON error sent when a worker call fails."""
        logger.info(f"worker timeout: {worker_address}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
        }
        return json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Aggregate the live status of every registered worker.

        Returns:
            Dict with the sorted union of ``model_names`` and the summed
            ``speed`` and ``queue_length`` of the workers that answered.
        """
        model_names = set()
        speed = 0
        queue_length = 0

        # Snapshot the keys: each status fetch is a network call, and the
        # registry may change while it is in flight.
        for w_name in list(self.worker_info):
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        model_names = sorted(model_names)
        return {
            "model_names": model_names,
            "speed": speed,
            "queue_length": queue_length,
        }

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a chosen worker.

        Args:
            params: Request payload; ``params["model"]`` selects the worker.

        Yields:
            Null-terminated byte chunks relayed from the worker, or a single
            null-terminated JSON error when no worker is available or the
            request to the worker fails.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            yield self.handle_no_worker(params)
            # Stop here: without an address the request below would fail and
            # emit a second, misleading "worker timeout" error.
            return

        try:
            response = requests.post(
                worker_addr + "/worker_generate_stream",
                json=params,
                stream=True,
                timeout=WORKER_API_TIMEOUT,
            )
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException:
            yield self.handle_worker_timeout(worker_addr)
|
|
|
|
|
|
# FastAPI application object; the route handlers in this module are
# registered on it through its decorators.
app = FastAPI()
|
|
|
|
|
|
@app.post("/register_worker")
async def register_worker(request: Request):
    """Register the worker described by the JSON request body.

    The body must contain ``worker_name`` and ``check_heart_beat``;
    ``worker_status`` defaults to None and ``multimodal`` to False. The call
    is forwarded to the module-level ``controller``.
    """
    payload = await request.json()
    worker_name = payload["worker_name"]
    check_heart_beat = payload["check_heart_beat"]
    worker_status = payload.get("worker_status", None)
    multimodal = payload.get("multimodal", False)
    controller.register_worker(
        worker_name, check_heart_beat, worker_status, multimodal
    )
|
|
|
|
|
|
@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Ask the module-level ``controller`` to re-register all known workers."""
    # The controller method returns nothing, so there is no result to keep.
    controller.refresh_all_workers()
|
|
|
|
|
|
@app.post("/list_models")
async def list_models():
    """Return ``{"models": [...]}`` with every model name the controller knows."""
    return {"models": controller.list_models()}
|
|
|
|
|
|
@app.post("/list_multimodal_models")
async def list_multimodal_models():
    """Return ``{"models": [...]}`` limited to models on multimodal workers."""
    return {"models": controller.list_multimodal_models()}
|
|
|
|
|
|
@app.post("/list_language_models")
async def list_language_models():
    """Return ``{"models": [...]}`` limited to models on non-multimodal workers."""
    return {"models": controller.list_language_models()}
|
|
|
|
|
|
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Return ``{"address": ...}`` for the model named in the JSON body.

    The body must contain ``model``; the worker is chosen by the module-level
    ``controller``.
    """
    payload = await request.json()
    model_name = payload["model"]
    return {"address": controller.get_worker_address(model_name)}
|
|
|
|
|
|
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Forward a worker heart beat to the controller.

    The JSON body must contain ``worker_name`` and ``queue_length``. The
    response is ``{"exist": ...}`` carrying the controller's return value.
    """
    payload = await request.json()
    worker_name = payload["worker_name"]
    queue_length = payload["queue_length"]
    return {"exist": controller.receive_heart_beat(worker_name, queue_length)}
|
|
|
|
|
|
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream the controller's proxied generation output for the JSON body."""
    params = await request.json()
    return StreamingResponse(controller.worker_api_generate_stream(params))
|
|
|
|
|
|
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Return the status dict produced by the controller's worker API."""
    status = controller.worker_api_get_status()
    return status
|
|
|
|
|
|
@app.get("/test_connection")
async def test_connection(request: Request):
    """Connectivity probe: always responds with the string ``"success"``."""
    # Named after its route so that it no longer rebinds the module-level
    # name of the /worker_get_status handler.
    return "success"
|
|
|
|
|
|
def create_controller():
    """Parse the command line and build the ``Controller``.

    Options: ``--host`` (default ``localhost``), ``--port`` (default 21001),
    ``--dispatch-method`` (``lottery`` or ``shortest_queue``, default
    ``shortest_queue``) and the ``--ssl`` flag.

    Returns:
        A ``(args, controller)`` tuple: the parsed namespace and a
        ``Controller`` created with the selected dispatch method.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="localhost")
    cli.add_argument("--port", type=int, default=21001)
    cli.add_argument(
        "--dispatch-method",
        type=str,
        choices=["lottery", "shortest_queue"],
        default="shortest_queue",
    )
    # store_true already implies default=False and an optional flag.
    cli.add_argument(
        "--ssl",
        action="store_true",
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = cli.parse_args()
    logger.info(f"args: {args}")

    return args, Controller(args.dispatch_method)
|
|
|
|
|
|
if __name__ == "__main__":
    # `controller` becomes the module-level object the route handlers use.
    args, controller = create_controller()

    server_options = {
        "host": args.host,
        "port": args.port,
        "log_level": "info",
    }
    if args.ssl:
        # With --ssl both environment variables must be set; a missing one
        # raises KeyError before the server starts.
        server_options["ssl_keyfile"] = os.environ["SSL_KEYFILE"]
        server_options["ssl_certfile"] = os.environ["SSL_CERTFILE"]

    uvicorn.run(app, **server_options)
|