mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add managed vllm server
This commit is contained in:
parent
578175a709
commit
e6ac3abdcb
9 changed files with 597 additions and 15 deletions
|
|
@ -18,6 +18,7 @@ from atroposlib.envs.server_handling.server_baseline import (
|
|||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
from atroposlib.envs.server_handling.sglang_server import SGLangServer
|
||||
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
|
||||
from atroposlib.envs.server_handling.vllm_server import VLLMServer
|
||||
|
||||
|
||||
class ServerManagerConfig(BaseModel):
|
||||
|
|
@ -58,6 +59,8 @@ class ServerManager:
|
|||
server_class = TrlVllmServer
|
||||
elif configs.server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
elif configs.server_type == "vllm":
|
||||
server_class = VLLMServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs.server_type}")
|
||||
else:
|
||||
|
|
@ -67,6 +70,8 @@ class ServerManager:
|
|||
server_class = TrlVllmServer
|
||||
elif configs[0].server_type == "sglang":
|
||||
server_class = SGLangServer
|
||||
elif configs[0].server_type == "vllm":
|
||||
server_class = VLLMServer
|
||||
else:
|
||||
raise ValueError(f"Invalid server type: {configs[0].server_type}")
|
||||
if testing:
|
||||
|
|
@ -198,7 +203,7 @@ class ServerManager:
|
|||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
@ -231,7 +236,7 @@ class ServerManager:
|
|||
for completion in completions[1:]:
|
||||
out.choices.extend(completion.choices)
|
||||
return out
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
@ -276,7 +281,7 @@ class ServerManager:
|
|||
finish_reasons.extend(out_finish_reasons)
|
||||
return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
|
||||
|
||||
is_train = kwargs.get("split", "train") == "train"
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
await self.wait_for_sem(is_train)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue