"""Benchmarking script to test the throughput of serving workers."""
|
|
import argparse
|
|
import json
|
|
|
|
import requests
|
|
import threading
|
|
import time
|
|
|
|
from fastchat.conversation import get_conv_template
|
|
|
|
|
|
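
# Example invocation (the script file name here is an assumption; the flags
# match the argparse configuration at the bottom of this file):
#   python test_throughput.py --model-name vicuna --n-thread 16
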

def main():
    # `args` is a module-level global set in the __main__ block below.
    # controller_addr is hoisted out of the else branch so it is also defined
    # for send_request() when --worker-address and --test-dispatch are
    # combined.
    controller_addr = args.controller_address
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # Ask the controller to re-register its workers, then list the
        # models it currently serves.
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": args.model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        return

    # Build one long prompt from the Vicuna conversation template and
    # replicate it for every client thread, so all requests do comparable work.
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    prompt_template = conv.get_prompt()
    prompts = [prompt_template for _ in range(args.n_thread)]

    headers = {"User-Agent": "fastchat Client"}
    # temperature 0.0 asks the worker for greedy decoding, which keeps runs
    # deterministic and comparable.
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompts[i],
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
            # "stop": conv.sep,
        }
        for i in range(len(prompts))
    ]
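
    # send_request() writes its word count into results[i] instead of
    # returning it, because threading.Thread targets cannot return values;
    # each thread owns exactly one slot, so no locking is needed.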
    def send_request(results, i):
        if args.test_dispatch:
            # Re-query the controller per request to exercise its dispatch logic.
            ret = requests.post(
                controller_addr + "/get_worker_address",
                json={"model": args.model_name},
            )
            thread_worker_addr = ret.json()["address"]
        else:
            thread_worker_addr = worker_addr
        print(f"thread {i} goes to {thread_worker_addr}")
        response = requests.post(
            thread_worker_addr + "/worker_generate_stream",
            headers=headers,
            json=ploads[i],
            stream=False,
        )
        # The worker streams b"\0"-delimited JSON chunks; the stream ends with
        # a delimiter, so k[-1] is empty and k[-2] is the last complete chunk.
        k = list(
            response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
        )
        # print(k)
        last_chunk = json.loads(k[-2].decode("utf-8"))
        response_new_words = last_chunk["text"]
        error_code = last_chunk["error_code"]
        # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}")
        # The streamed text echoes the prompt, so subtract the prompt's word
        # count to get only the newly generated words.
        results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    # use N threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        # time.sleep(0.5)
        threads.append(t)

    for t in threads:
        t.join()
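
    # The requests use stream=False, so each POST returns only after its
    # generation has finished; joining the threads therefore already covers
    # the full completion time.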
print(f"Time (POST): {time.time() - tik} s")
|
|
# n_words = 0
|
|
# for i, response in enumerate(results):
|
|
# # print(prompt[i].replace(conv.sep, "\n"), end="")
|
|
# # make sure the streaming finishes at EOS or stopping criteria
|
|
# k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
|
|
# response_new_words = json.loads(k[-2].decode("utf-8"))["text"]
|
|
# # print(response_new_words)
|
|
# n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
|
|
n_words = sum(results)
|
|
time_seconds = time.time() - tik
|
|
print(
|
|
f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
|
|
f"throughput: {n_words / time_seconds} words/s."
|
|
)
|
|
|
|
|
|
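    # Note: the reported throughput is whitespace-separated words per second,
    # not model tokens per second.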


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="vicuna")
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--n-thread", type=int, default=8)
    parser.add_argument("--test-dispatch", action="store_true")
    args = parser.parse_args()

    main()