BLEUBERI/eval/arena-hard-v2.0/utils/sglang_server.py
2025-06-04 20:36:43 +00:00

318 lines
10 KiB
Python

# Author: Peter Jin
from typing import Any, Optional, Union
from dataclasses import dataclass
import concurrent.futures
import json
import multiprocessing as mp
import os
import subprocess
import time
import traceback
import urllib.request
try:
import sglang
import sglang.srt.entrypoints.http_server
import sglang.srt.server_args
import sglang.srt.utils
except ImportError:
sglang = None
def _sglang_server_init(tx, server_args):
    """Entry point of the child process that hosts the SGLang HTTP server.

    Args:
        tx: Object with a ``send`` method (the executor passes the write end
            of a multiprocessing pipe). ``None`` is sent through it from the
            launch callback.
        server_args: Forwarded unchanged to sglang's ``launch_server``.

    NOTE(review): the callback is presumably invoked by sglang once the
    server has finished starting up -- confirm against the installed sglang
    version, which must also support the ``launch_callback`` keyword.
    """
    sglang.srt.entrypoints.http_server.launch_server(
        server_args,
        # The parent only waits for *a* message; the payload is irrelevant.
        launch_callback=lambda: tx.send(None),
    )
def _sglang_server_heartbeat(
    _host: str = None,
    _port: int = None,
):
    """Probe the server's ``/health_generate`` endpoint once.

    The request carries no body and uses a 2-second timeout. Any error
    raised by ``urllib`` (connection refused, timeout, HTTP error status)
    propagates to the caller.

    Args:
        _host: Host name or address of the server.
        _port: TCP port of the server.

    Returns:
        ``{"output": <response body decoded as UTF-8 text>}``.
    """
    probe = urllib.request.Request(
        f"http://{_host}:{_port}/health_generate",
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
    )
    with urllib.request.urlopen(probe, timeout=2) as response:
        body = response.read()
    # The body is handed back as plain text; it is not parsed as JSON.
    return {"output": body.decode("utf-8")}
def _sglang_server_submit(
    input_ids: Optional[list] = None,
    sampling_params: Optional[dict] = None,
    _ctr: Optional[int] = None,
    _host: Optional[str] = None,
    _port: Optional[int] = None,
):
    """Send one request to the server's ``/generate`` endpoint.

    The JSON body is ``{"input_ids": ..., "sampling_params": ...}``. No
    timeout is set on the request, so the call blocks until the server
    answers; any ``urllib`` error propagates to the caller.

    Args:
        input_ids: Token ids of a single prompt.
        sampling_params: Sampling parameters, forwarded verbatim.
        _ctr: Caller-side request counter, echoed back untouched so the
            caller can match the response to its bookkeeping entry.
        _host: Host name or address of the server (a string; it is
            interpolated into the request URL).
        _port: TCP port of the server.

    Returns:
        ``{"output": <parsed JSON response>, "_ctr": _ctr}``.
    """
    req_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    req_body = {
        "input_ids": input_ids,
        "sampling_params": sampling_params,
    }
    req = urllib.request.Request(
        f"http://{_host}:{_port}/generate",
        headers=req_headers,
        # Supplying `data` makes urllib issue a POST.
        data=json.dumps(req_body).encode("utf-8"),
    )
    with urllib.request.urlopen(req) as out:
        out_data = out.read()
    output = json.loads(out_data.decode("utf-8"))
    return {
        "output": output,
        "_ctr": _ctr,
    }
@dataclass
class SGLangRequest:
    """A finished request: its sequence number, caller key and output."""

    # Sequence number assigned to the request when it was submitted.
    _ctr: int
    # Caller-supplied key associated with the request.
    _key: Any
    # Output produced for the request.
    _output: Any

    def _counter(self) -> int:
        """Return the request's sequence number."""
        return self._ctr

    def key(self) -> Any:
        """Return the caller-supplied key."""
        return self._key

    def result(self) -> Any:
        """Return the stored output."""
        return self._output
class SGLangServerExecutor:
    """Start one SGLang HTTP server and dispatch generation requests to it.

    The server is started either in a ``multiprocessing`` "spawn" child
    (``backend="spawn"``, needs ``sglang`` importable in this interpreter) or
    as ``python -m sglang.launch_server`` (``backend="subprocess"``).
    Requests are issued from a thread pool; ``submit`` queues them and
    ``as_completed`` yields ``SGLangRequest`` objects as they finish.
    ``join`` kills the server.
    """

    def __init__(
        self,
        max_workers: int = 256,
        server_host: Optional[str] = "127.0.0.1",
        server_port: Optional[int] = 30000,
        backend: str = "spawn",
        # NB(peter): upgrade sglang here.
        subprocess_venv_path: Optional[str] = None,
        **kwargs
    ):
        """Launch the server and block until it reports ready.

        Args:
            max_workers: Size of the request thread pool.
            server_host: Host passed to the server as ``host``. If ``None``,
                a ``host`` entry in ``kwargs`` is used instead.
            server_port: Port passed to the server as ``port``. If ``None``,
                a ``port`` entry in ``kwargs`` is used instead.
            backend: ``"spawn"`` or ``"subprocess"``.
            subprocess_venv_path: For the subprocess backend, a virtualenv
                whose ``bin/python`` runs the server; otherwise ``python``
                from ``PATH`` is used.
            **kwargs: Extra server arguments. Defaults are filled in for
                ``skip_tokenizer_init``, ``log_level`` and
                ``attention_backend``.

        Raises:
            AssertionError: host/port given both explicitly and in kwargs.
            NotImplementedError: unknown backend, or (subprocess backend) a
                kwarg value that is not bool/int/float/str.
            ImportError: spawn backend requested but sglang is not installed.
            RuntimeError: the server process exited before becoming ready.
        """
        if server_host is not None:
            assert kwargs.get("host", None) is None
            kwargs["host"] = server_host
        if server_port is not None:
            assert kwargs.get("port", None) is None
            kwargs["port"] = server_port
        kwargs.setdefault("skip_tokenizer_init", True)
        kwargs.setdefault("log_level", "warning")
        # NB(peter): fa3 backend requires sglang >= 0.4.5.
        kwargs.setdefault("attention_backend", "fa3")

        ready_rx = None
        if backend == "spawn":
            if sglang is None:
                raise ImportError(
                    "backend='spawn' requires the `sglang` package; "
                    "install it or use backend='subprocess'"
                )
            server_args = sglang.srt.server_args.ServerArgs(
                **kwargs
            )
            mpctx = mp.get_context("spawn")
            ready_rx, ready_tx = mpctx.Pipe(False)
            proc = mpctx.Process(
                target=_sglang_server_init,
                args=(ready_tx, server_args,),
            )
            proc.start()
        elif backend == "subprocess":
            # Subprocess-based launch avoids the sglang pip requirement in
            # this interpreter (but needs an existing venv that has it).
            if subprocess_venv_path is not None:
                python_path = os.path.join(subprocess_venv_path, "bin/python")
            else:
                python_path = "python"
            cmd = self._build_subprocess_cmd(python_path, kwargs)
            print(f"DEBUG: SGLangServerExecutor: subprocess command = {cmd}")
            proc = subprocess.Popen(cmd, shell=False, text=True)
        else:
            # Includes the reserved-but-unimplemented "docker" backend.
            raise NotImplementedError(f"unsupported backend: {backend!r}")

        # Use the merged kwargs so a host/port supplied only through
        # **kwargs is also the address requests are sent to.
        self._server_host = kwargs.get("host")
        self._server_port = kwargs.get("port")
        self._server_proc = proc
        self._server_pid = proc.pid
        assert self._server_pid is not None
        self._server_backend = backend
        self._pool_ctr = 0          # next request sequence number
        self._pool_dict = dict()    # future -> (sequence number, key)
        self._pool_work = set()     # futures not yet yielded by as_completed
        self._pool_exec = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        )
        print("DEBUG: SGLangServerExecutor: post init...")
        if backend == "spawn":
            self._await_spawn_ready(ready_rx)
        else:
            self._await_subprocess_ready()
        print("DEBUG: SGLangServerExecutor: post init: done")

    @staticmethod
    def _build_subprocess_cmd(python_path, server_kwargs):
        """Translate server kwargs into a ``sglang.launch_server`` argv list."""
        cmd = [
            python_path,
            "-m",
            "sglang.launch_server",
        ]
        for key, arg in server_kwargs.items():
            flag = key.replace("_", "-")
            # bool must be tested first: bool is a subclass of int.
            if isinstance(arg, bool):
                # NOTE(review): `--no-<flag>` is emitted for False values;
                # confirm the sglang CLI accepts that form for every flag.
                cmd.append(f"--{flag}" if arg else f"--no-{flag}")
            elif isinstance(arg, (float, int)):
                cmd.extend([f"--{flag}", str(arg)])
            elif isinstance(arg, str):
                cmd.extend([f"--{flag}", arg])
            else:
                raise NotImplementedError(
                    f"unsupported type for server argument {key!r}: "
                    f"{type(arg).__name__}"
                )
        return cmd

    def _await_spawn_ready(self, ready_rx):
        """Block until the spawned child signals readiness over the pipe."""
        # Poll instead of a bare recv(): this process still holds the write
        # end of the pipe, so recv() would never see EOF if the child died.
        while not ready_rx.poll(1.0):
            if not self._server_proc.is_alive():
                raise RuntimeError(
                    "sglang server process exited before becoming ready "
                    f"(exit code {self._server_proc.exitcode})"
                )
        _ = ready_rx.recv()

    def _await_subprocess_ready(self):
        """Retry the health check every 10 s until it succeeds."""
        while True:
            try:
                _sglang_server_heartbeat(
                    _host=self._server_host,
                    _port=self._server_port,
                )
            except Exception as e:
                print(f"DEBUG: SGLangServerExecutor: post init: retry heartbeat: exception = {e}")
                # Stop retrying once the server process itself is gone;
                # otherwise a bad command line would loop here forever.
                exit_code = self._server_proc.poll()
                if exit_code is not None:
                    raise RuntimeError(
                        "sglang server subprocess exited before becoming "
                        f"ready (exit code {exit_code})"
                    ) from e
                print("DEBUG: SGLangServerExecutor: post init: retry heartbeat: sleep...")
                time.sleep(10.0)
                continue
            print("DEBUG: SGLangServerExecutor: post init: heartbeat: ok")
            break

    def join(self):
        """Kill the server process. Calling it again is a no-op.

        Raises:
            NotImplementedError: the instance has an unknown backend.
        """
        print("DEBUG: SGLangServerExecutor: join...")
        backend = self._server_backend
        if backend not in ("spawn", "subprocess"):
            raise NotImplementedError
        proc = self._server_proc
        if proc is None:
            # Already joined; never signal a pid that may have been reused.
            return
        if backend == "spawn":
            sglang.srt.utils.kill_process_tree(self._server_pid)
        else:
            # NOTE(review): this kills only the launcher process; confirm
            # that sglang's own worker processes exit along with it.
            proc.kill()
            # Reap the child so it does not linger as a zombie.
            proc.wait()
        self._server_proc = None
        self._server_pid = None

    def submit(
        self,
        input_ids: Union[None, list[int], list[list[int]]] = None,
        prompt_token_ids: Union[None, list[int], list[list[int]]] = None,
        sampling_params: Union[None, dict[str, Any], list[dict[str, Any]]] = None,
        keys: Optional[list[Any]] = None,
    ):
        """Queue one request per prompt on the thread pool.

        Args:
            input_ids: One prompt (``list[int]``) or a batch of prompts.
            prompt_token_ids: Alias for ``input_ids``; give only one of them.
            sampling_params: One dict shared by every prompt, or a list with
                one dict per prompt.
            keys: Optional per-prompt keys, returned with each result. When
                omitted, every result's key is ``None``.

        Returns:
            The list of futures created, in prompt order.

        Raises:
            AssertionError: both or neither of the prompt arguments given.
            NotImplementedError: ``sampling_params`` is neither a dict nor a
                non-empty list of dicts.
            IndexError: a ``sampling_params`` or ``keys`` list is shorter
                than the batch.
        """
        # Raised explicitly (rather than via `assert`) so the checks still
        # run under `python -O`; the exception type is unchanged.
        if input_ids is not None and prompt_token_ids is not None:
            raise AssertionError(
                "SGLangServerExecutor.submit supports either `input_ids` or `prompt_token_ids` but not both"
            )
        if input_ids is None:
            input_ids = prompt_token_ids
        if input_ids is None:
            raise AssertionError(
                "SGLangServerExecutor.submit requires `input_ids` or `prompt_token_ids`"
            )
        # A flat list of ints is a single prompt: wrap it into a batch of 1.
        if (
            isinstance(input_ids, list) and
            len(input_ids) > 0 and
            isinstance(input_ids[0], int)
        ):
            input_ids = [input_ids]
        batch_size = len(input_ids)

        # Resolve every per-prompt argument before sending anything, so a
        # bad argument cannot leave part of the batch in flight and
        # untracked.
        requests = []
        for batch_idx in range(batch_size):
            if isinstance(sampling_params, dict):
                params = sampling_params
            elif isinstance(sampling_params, list):
                if len(sampling_params) <= 0:
                    raise NotImplementedError
                params = sampling_params[batch_idx]
                if not isinstance(params, dict):
                    raise NotImplementedError
            else:
                raise NotImplementedError
            key = keys[batch_idx] if keys is not None else None
            requests.append((input_ids[batch_idx], params, key))

        work = []
        for prompt_ids, params, key in requests:
            ctr = self._pool_ctr
            w = self._pool_exec.submit(
                _sglang_server_submit,
                input_ids=prompt_ids,
                sampling_params=params,
                _ctr=ctr,
                _host=self._server_host,
                _port=self._server_port,
            )
            self._pool_dict[w] = (ctr, key)
            self._pool_work.add(w)
            self._pool_ctr += 1
            work.append(w)
        return work

    def as_completed(self):
        """Yield an ``SGLangRequest`` for each outstanding request as it finishes.

        Iteration ends when no submitted work remains. If a request failed,
        its exception is raised from here. Results that were already
        complete but not yet yielded when iteration stops early (the
        consumer breaks out, or a request raises) are kept and yielded by
        the next call.
        """
        # NB(peter): the `wait` version _should be_ re-entrant-safe.
        while self._pool_work:
            done, not_done = concurrent.futures.wait(
                self._pool_work,
                return_when=concurrent.futures.FIRST_COMPLETED
            )
            self._pool_work = not_done
            ready = list(done)
            try:
                while ready:
                    w = ready.pop()
                    ctr, key = self._pool_dict.pop(w)
                    result = w.result()
                    assert ctr == result["_ctr"]
                    yield SGLangRequest(
                        _ctr=ctr,
                        _key=key,
                        _output=result["output"],
                    )
            finally:
                # Put back whatever was finished but not handed out.
                if ready:
                    self._pool_work.update(ready)