init commit

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 11:32:09 -05:00
parent 887a94374c
commit b9291aa29f
5 changed files with 357 additions and 0 deletions

View file

@ -346,6 +346,66 @@ class ServerManager:
**kwargs
)
async def get_logprobs(self, **kwargs) -> dict:
    """
    Route a normalized get_logprobs request to the most available server.

    When more than ``self.max_n_completions`` sequences are requested, the
    request is split into several smaller requests that are awaited
    together, and their sequence-level outputs are concatenated in
    request order.

    Keyword Args:
        n: Number of sequences requested. A missing or ``None`` value is
            treated as 1 when deciding whether to split; the value itself
            is forwarded unchanged.
        split: ``"train"`` (the default) ranks servers by ``server.sem``;
            any other value ranks them by ``server.eval_sem``. This key is
            removed before the request is forwarded.
        **kwargs: Everything else is passed through to the selected
            server's ``get_logprobs``.

    Returns:
        A normalized dict with:
        - prompt_tokens (taken from the first sub-request when split)
        - sequence_token_ids
        - sequence_logprobs
        - sequence_topk_token_ids
        - sequence_topk_logprobs
        - finish_reasons
    """
    sequence_keys = (
        "sequence_token_ids",
        "sequence_logprobs",
        "sequence_topk_token_ids",
        "sequence_topk_logprobs",
        "finish_reasons",
    )

    # An explicit n=None would make the comparison below raise TypeError.
    n = kwargs.get("n")
    if n is None:
        n = 1

    if n > self.max_n_completions:
        # Split into multiple requests and merge sequence-level outputs.
        pending = []
        remaining = n
        while remaining > 0:
            chunk_n = min(remaining, self.max_n_completions)
            # Each sub-request gets its own kwargs with its own "n";
            # "split" is still present and is consumed by the sub-call.
            pending.append(self.get_logprobs(**{**kwargs, "n": chunk_n}))
            remaining -= chunk_n
        results = await asyncio.gather(*pending)

        merged = {"prompt_tokens": results[0]["prompt_tokens"]}
        for key in sequence_keys:
            merged[key] = [item for result in results for item in result[key]]
        return merged

    is_train = kwargs.pop("split", "train") == "train"
    await self.wait_for_sem(is_train)

    # Index 0 is the fallback when no server reports itself healthy.
    best_index = 0
    best_slots = -1
    for index, server in enumerate(self.servers):
        if not server.server_healthy:
            continue
        # Reads the semaphore's private counter as the number of free
        # slots (presumably an asyncio.Semaphore -- confirm).
        slots = server.sem._value if is_train else server.eval_sem._value
        # Strict comparison: the first server wins ties.
        if slots > best_slots:
            best_index, best_slots = index, slots

    return await self.servers[best_index].get_logprobs(**kwargs)
@asynccontextmanager
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
most_available_server = 0