init commit

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 11:32:09 -05:00
parent 887a94374c
commit b9291aa29f
5 changed files with 357 additions and 0 deletions


@@ -3,6 +3,7 @@
import asyncio
import warnings
from typing import Any, Dict, List, Tuple
import aiohttp
import openai
@@ -231,6 +232,143 @@ class VLLMServer(APIServer):
            finish_reasons_list,
        )

    @staticmethod
    def _normalize_topk_entry(
        token_logprobs_entry: Any,
    ) -> Tuple[List[int], List[float]]:
        """
        Normalize a single token-position logprob payload into parallel top-k arrays.

        Supports common structures from vLLM responses:
        - dict: {token_id: logprob, ...}
        - list[dict]: [{token_id: logprob}, ...]
        """
        if isinstance(token_logprobs_entry, dict):
            items = list(token_logprobs_entry.items())
            return [int(k) for k, _ in items], [float(v) for _, v in items]
        if isinstance(token_logprobs_entry, list):
            token_ids: List[int] = []
            logprobs: List[float] = []
            for item in token_logprobs_entry:
                if not isinstance(item, dict):
                    continue
                for key, value in item.items():
                    token_ids.append(int(key))
                    logprobs.append(float(value))
            return token_ids, logprobs
        return [], []
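
For illustration (not part of the diff), a minimal sketch of the two payload shapes this helper normalizes; the token IDs and logprob values are invented, and it assumes VLLMServer is importable from this module:

    # Illustrative only: token IDs and logprob values are made up.
    dict_entry = {12: -0.11, 345: -2.4}        # {token_id: logprob, ...}
    list_entry = [{12: -0.11}, {345: -2.4}]    # [{token_id: logprob}, ...]

    ids, lps = VLLMServer._normalize_topk_entry(dict_entry)
    # ids == [12, 345]; lps == [-0.11, -2.4]
    assert (ids, lps) == VLLMServer._normalize_topk_entry(list_entry)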

    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
        """
        Fetch normalized logprobs from vLLM /generate with optional top-k.

        Args:
            top_k / top_logprobs: Optional number of logprobs per position.
                Defaults to 1.
            prompt or input_ids: Input text or token IDs.

        Returns:
            Normalized dict:
            - prompt_tokens
            - sequence_token_ids
            - sequence_logprobs
            - sequence_topk_token_ids
            - sequence_topk_logprobs
            - finish_reasons
        """
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for get_logprobs!"
        top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
        top_k = max(1, top_k)

        # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt.
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))

        # Check for double BOS token.
        if (
            len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]

        if "max_new_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
        if "max_completion_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
        kwargs.pop("model", None)

        request_data = {
            "prompt": {"prompt_token_ids": prompt_tokens},
            "logprobs": top_k,
        }
        request_data.update(kwargs)

        # Keep semaphore behavior consistent with other server calls.
        split = request_data.pop("split", "train")
        sem = self.sem if split == "train" else self.eval_sem
        while not self.server_healthy:
            await asyncio.sleep(1)

        async with sem:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.config.base_url.replace('/v1', '')}/generate",
                    json=request_data,
                    headers=(
                        {"Authorization": f"Bearer {self.config.api_key}"}
                        if self.config.api_key
                        else {}
                    ),
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                ) as response:
                    response.raise_for_status()
                    results = await response.json()
        sequence_topk_token_ids: List[List[List[int]]] = []
        sequence_topk_logprobs: List[List[List[float]]] = []
        sequence_token_ids: List[List[int]] = []
        sequence_logprobs: List[List[float]] = []
        finish_reasons: List[Any] = []
        for token_logprobs_seq, finish_reason in zip(
            results["logprobs"], results["finish_reasons"]
        ):
            seq_topk_token_ids: List[List[int]] = []
            seq_topk_logprobs: List[List[float]] = []
            seq_token_ids: List[int] = []
            seq_logprobs: List[float] = []
            for token_logprobs_entry in token_logprobs_seq:
                topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry)
                seq_topk_token_ids.append(topk_ids)
                seq_topk_logprobs.append(topk_lps)
                seq_token_ids.append(topk_ids[0] if topk_ids else -1)
                seq_logprobs.append(topk_lps[0] if topk_lps else 0.0)
            sequence_topk_token_ids.append(seq_topk_token_ids)
            sequence_topk_logprobs.append(seq_topk_logprobs)
            sequence_token_ids.append(seq_token_ids)
            sequence_logprobs.append(seq_logprobs)
            finish_reasons.append(finish_reason)

        return {
            "prompt_tokens": prompt_tokens,
            "sequence_token_ids": sequence_token_ids,
            "sequence_logprobs": sequence_logprobs,
            "sequence_topk_token_ids": sequence_topk_token_ids,
            "sequence_topk_logprobs": sequence_topk_logprobs,
            "finish_reasons": finish_reasons,
        }
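
As a usage sketch (not part of the diff): the call below only passes parameters that this method handles explicitly; the `server` variable is assumed to be an already-constructed VLLMServer, and the `[0]` indexing assumes a single returned sequence per prompt.

    # Hypothetical usage sketch; `server` is an already-constructed VLLMServer.
    async def score_prompt(server):
        result = await server.get_logprobs(
            prompt="The capital of France is",
            top_k=5,              # top-k logprobs per generated position
            max_new_tokens=8,     # renamed to max_tokens before the request
            split="eval",         # routes the call through the eval semaphore
        )
        # The method treats index 0 of each per-position top-k list as the sampled token.
        first_seq_ids = result["sequence_token_ids"][0]
        first_seq_lps = result["sequence_logprobs"][0]
        for tok_id, lp in zip(first_seq_ids, first_seq_lps):
            print(tok_id, lp)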

def resolve_openai_configs(
    default_server_configs,