prompt logprobs simplicity

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 22:06:49 -05:00
parent f1c20591b6
commit 5aaf7a346c
5 changed files with 69 additions and 69 deletions

View file

@ -641,10 +641,7 @@ class APIServer(ABC):
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Trainer-agnostic prompt-logprob API with normalized output schema.
This default implementation is built from `tokens_and_logprobs_completion`
and returns prompt-side singleton top-k values.
Prompt-logprob API with strict normalized output schema.
Returns:
Dict with:
@ -652,20 +649,8 @@ class APIServer(ABC):
- prompt_topk_token_ids: List[List[int]]
- prompt_topk_logprobs: List[List[float]]
"""
(
prompt_tokens,
_output_tokens_list,
_output_logprobs_list,
_finish_reasons,
) = await self.tokens_and_logprobs_completion(**kwargs)
# Fallback path does not have true prompt-logprobs, so we provide
# interface-compatible singleton values for each prompt token.
prompt_topk_token_ids = [[token_id] for token_id in prompt_tokens]
prompt_topk_logprobs = [[1.0] for _ in prompt_tokens]
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": prompt_topk_token_ids,
"prompt_topk_logprobs": prompt_topk_logprobs,
}
raise NotImplementedError(
f"{self.__class__.__name__}.get_logprobs must be implemented by the "
"server backend and must return prompt_tokens, "
"prompt_topk_token_ids, and prompt_topk_logprobs."
)