add SGLang-specific token-level logprob handling and a ServerManager/baseline tokens/logprobs fn

Dakota 2025-10-16 12:38:03 -05:00
parent 4862e9972f
commit c36ec29656
4 changed files with 512 additions and 37 deletions


@@ -16,6 +16,7 @@ from atroposlib.envs.server_handling.server_baseline import (
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.sglang_server import SGLangServer
class ServerManagerConfig(BaseModel):
@@ -54,6 +55,8 @@ class ServerManager:
                server_class = OpenAIServer
            elif configs.server_type == "trl":
                server_class = TrlVllmServer
            elif configs.server_type == "sglang":
                server_class = SGLangServer
            else:
                raise ValueError(f"Invalid server type: {configs.server_type}")
        else:
@@ -61,6 +64,8 @@ class ServerManager:
                server_class = OpenAIServer
            elif configs[0].server_type == "trl":
                server_class = TrlVllmServer
            elif configs[0].server_type == "sglang":
                server_class = SGLangServer
            else:
                raise ValueError(f"Invalid server type: {configs[0].server_type}")
        if testing:
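For orientation, here is a rough usage sketch of the new branch, assuming the usual list-of-configs path. Aside from server_type, which appears in this diff, the config class and field names below are assumptions and may not match atroposlib's actual API:

# Hypothetical sketch (not part of the diff): route requests through SGLang
# by setting server_type="sglang". Fields other than server_type are assumed.
configs = [
    APIServerConfig(                            # assumed config class name
        server_type="sglang",
        model_name="Qwen/Qwen2.5-7B-Instruct",  # assumed field
        base_url="http://localhost:30000/v1",   # assumed field; SGLang's OpenAI-compatible endpoint
    )
]
manager = ServerManager(configs)                # dispatches to SGLangServer via the elif above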
@@ -241,6 +246,51 @@ class ServerManager:
        )
        return await self.servers[most_available_server].completion(**kwargs)

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Get tokens and logprobs from completion.
        Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
        """
        n = kwargs.get("n", 1)
        if n > self.max_n_completions:
            # Split into multiple completions
            results = []
            total_n = n
            while total_n > 0:
                n_to_use = min(total_n, self.max_n_completions)
                kwargs["n"] = n_to_use
                results.append(self.tokens_and_logprobs_completion(**kwargs))
                total_n -= n_to_use
            results = await asyncio.gather(*results)
            # Merge results - prompt_tokens should be same, extend output lists
            prompt_tokens = results[0][0]
            output_tokens = []
            output_logprobs = []
            finish_reasons = []
            for _, out_tokens, out_logprobs, out_finish_reasons in results:
                output_tokens.extend(out_tokens)
                output_logprobs.extend(out_logprobs)
                finish_reasons.extend(out_finish_reasons)
            return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
        is_train = kwargs.get("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )
        return await self.servers[most_available_server].tokens_and_logprobs_completion(**kwargs)
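The new method reuses completion()'s least-loaded routing (pick the healthy server with the most free semaphore slots) but returns token-level data instead of text. A minimal usage sketch; apart from n and split, which appear in the code above, the kwargs shown are assumptions about the backend's completion interface:

# Hypothetical call (not part of the diff). If n exceeds max_n_completions,
# the method chunks the request and merges the per-chunk outputs.
prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
    await manager.tokens_and_logprobs_completion(
        model="Qwen/Qwen2.5-7B-Instruct",   # assumed kwarg, forwarded to the backend
        prompt="The capital of France is",  # assumed kwarg
        n=8,
        max_tokens=32,                      # assumed kwarg
        split="train",                      # routes through the training semaphore
    )
)
# output_tokens[i] and output_logprobs[i] align per sampled completion;
# finish_reasons[i] records why that sample stopped.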
    @asynccontextmanager
    async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
        most_available_server = 0