add managed vllm server

This commit is contained in:
Dakota 2025-11-07 13:06:49 -06:00
parent 578175a709
commit e6ac3abdcb
9 changed files with 597 additions and 15 deletions

View file

@@ -18,6 +18,7 @@ from atroposlib.envs.server_handling.server_baseline import (
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
+from atroposlib.envs.server_handling.vllm_server import VLLMServer
class ServerManagerConfig(BaseModel):
@@ -58,6 +59,8 @@ class ServerManager:
server_class = TrlVllmServer
elif configs.server_type == "sglang":
server_class = SGLangServer
+elif configs.server_type == "vllm":
+server_class = VLLMServer
else:
raise ValueError(f"Invalid server type: {configs.server_type}")
else:
@@ -67,6 +70,8 @@ class ServerManager:
server_class = TrlVllmServer
elif configs[0].server_type == "sglang":
server_class = SGLangServer
+elif configs[0].server_type == "vllm":
+server_class = VLLMServer
else:
raise ValueError(f"Invalid server type: {configs[0].server_type}")
if testing:
@@ -198,7 +203,7 @@ class ServerManager:
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
-is_train = kwargs.get("split", "train") == "train"
+is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
@@ -231,7 +236,7 @@ class ServerManager:
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
-is_train = kwargs.get("split", "train") == "train"
+is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
@@ -276,7 +281,7 @@ class ServerManager:
finish_reasons.extend(out_finish_reasons)
return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
-is_train = kwargs.get("split", "train") == "train"
+is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)