prompt logprobs

Jai Suphavadeeprasit 2026-03-03 21:56:11 -05:00
parent e98100e5f6
commit 439b9b129b
7 changed files with 73 additions and 138 deletions


@@ -262,7 +262,7 @@ class VLLMServer(APIServer):
     async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
         """
-        Fetch normalized logprobs from vLLM /generate with optional top-k.
+        Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
 
         Args:
             top_k / top_logprobs: Optional number of logprobs per position.
@@ -272,11 +272,8 @@ class VLLMServer(APIServer):
         Returns:
             Normalized dict:
                 - prompt_tokens
-                - sequence_token_ids
-                - sequence_logprobs
-                - sequence_topk_token_ids
-                - sequence_topk_logprobs
-                - finish_reasons
+                - prompt_topk_token_ids
+                - prompt_topk_logprobs
         """
         assert (
             kwargs.get("prompt", None) is not None
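For illustration only (not part of the diff): under the new contract, a three-token prompt scored with top_k=2 would come back roughly in the shape sketched below. Every value is hypothetical, and whether the first position is empty depends on how _normalize_topk_entry treats vLLM's missing logprobs entry for the first prompt token.

    # Hypothetical normalized result for a 3-token prompt with top_k=2.
    {
        "prompt_tokens": [101, 2054, 102],
        "prompt_topk_token_ids": [[], [2054, 1996], [102, 1012]],
        "prompt_topk_logprobs": [[], [-1.2, -2.9], [-0.4, -3.1]],
    }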
@@ -306,10 +303,8 @@ class VLLMServer(APIServer):
             kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
         kwargs.pop("model", None)
-        request_data = {
-            "prompt": {"prompt_token_ids": prompt_tokens},
-            "logprobs": top_k,
-        }
+        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}}
+        request_data["prompt_logprobs"] = top_k
         request_data.update(kwargs)
 
         # Keep semaphore behavior consistent with other server calls.
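For illustration only (not from the diff): with a hypothetical three-token prompt and top_k=2, the body the new code ends up posting to vLLM's /generate would look roughly like the sketch below; the token ids and the leftover kwarg are made up.

    # Sketch of the final request payload, assuming top_k has already been
    # resolved from the top_k / top_logprobs kwargs earlier in the method.
    request_data = {
        "prompt": {"prompt_token_ids": [101, 2054, 102]},  # hypothetical ids
        "prompt_logprobs": 2,  # top-k candidates per prompt position
        "max_tokens": 1,       # hypothetical leftover kwarg merged by update(kwargs)
    }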
@@ -333,40 +328,30 @@
                 response.raise_for_status()
                 results = await response.json()
 
-        sequence_topk_token_ids: List[List[List[int]]] = []
-        sequence_topk_logprobs: List[List[List[float]]] = []
-        sequence_token_ids: List[List[int]] = []
-        sequence_logprobs: List[List[float]] = []
-        finish_reasons: List[Any] = []
+        raw_prompt_logprobs = results.get("prompt_logprobs")
+        if raw_prompt_logprobs is None:
+            raise ValueError(
+                "vLLM /generate response missing 'prompt_logprobs'. "
+                "Ensure backend supports prompt logprobs."
+            )
 
-        for token_logprobs_seq, finish_reason in zip(
-            results["logprobs"], results["finish_reasons"]
-        ):
-            seq_topk_token_ids: List[List[int]] = []
-            seq_topk_logprobs: List[List[float]] = []
-            seq_token_ids: List[int] = []
-            seq_logprobs: List[float] = []
+        # Handle either direct [position] payloads or [sequence][position] payloads.
+        if raw_prompt_logprobs and isinstance(raw_prompt_logprobs[0], list):
+            prompt_entries = raw_prompt_logprobs[0]
+        else:
+            prompt_entries = raw_prompt_logprobs
 
-            for token_logprobs_entry in token_logprobs_seq:
-                topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry)
-                seq_topk_token_ids.append(topk_ids)
-                seq_topk_logprobs.append(topk_lps)
-                seq_token_ids.append(topk_ids[0] if topk_ids else -1)
-                seq_logprobs.append(topk_lps[0] if topk_lps else 0.0)
-            sequence_topk_token_ids.append(seq_topk_token_ids)
-            sequence_topk_logprobs.append(seq_topk_logprobs)
-            sequence_token_ids.append(seq_token_ids)
-            sequence_logprobs.append(seq_logprobs)
-            finish_reasons.append(finish_reason)
+        prompt_topk_token_ids: List[List[int]] = []
+        prompt_topk_logprobs: List[List[float]] = []
+        for entry in prompt_entries:
+            topk_ids, topk_lps = self._normalize_topk_entry(entry)
+            prompt_topk_token_ids.append(topk_ids)
+            prompt_topk_logprobs.append(topk_lps)
 
         return {
             "prompt_tokens": prompt_tokens,
-            "sequence_token_ids": sequence_token_ids,
-            "sequence_logprobs": sequence_logprobs,
-            "sequence_topk_token_ids": sequence_topk_token_ids,
-            "sequence_topk_logprobs": sequence_topk_logprobs,
-            "finish_reasons": finish_reasons,
+            "prompt_topk_token_ids": prompt_topk_token_ids,
+            "prompt_topk_logprobs": prompt_topk_logprobs,
         }
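A hedged usage sketch (not from this commit), assuming `server` is an already-constructed VLLMServer and `prompt` is in whatever form the rest of the method turns into prompt_tokens:

    # Sketch: fetch per-position prompt logprobs and walk the top-k candidates.
    result = await server.get_logprobs(prompt=prompt, top_logprobs=5)
    for pos, (token_id, cand_ids, cand_lps) in enumerate(
        zip(
            result["prompt_tokens"],
            result["prompt_topk_token_ids"],
            result["prompt_topk_logprobs"],
        )
    ):
        # Each position pairs the prompt token id with its top-k candidate ids
        # and logprobs, in whatever order _normalize_topk_entry produced them.
        print(pos, token_id, list(zip(cand_ids, cand_lps)))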