prompt logprobs

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 21:56:11 -05:00
parent e98100e5f6
commit 439b9b129b
7 changed files with 73 additions and 138 deletions

View file

@@ -641,40 +641,31 @@ class APIServer(ABC):
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Trainer-agnostic prompt-logprob API with normalized output schema.

    This default implementation is built from `tokens_and_logprobs_completion`
    and returns prompt-side singleton top-k values. It does not compute real
    prompt logprobs: each prompt token is reported as its own single
    candidate, paired with a constant placeholder logprob of 1.0.

    Args:
        **kwargs: Forwarded unchanged to `tokens_and_logprobs_completion`.

    Returns:
        Dict with:
            - prompt_tokens: List[int]
            - prompt_topk_token_ids: List[List[int]]
            - prompt_topk_logprobs: List[List[float]]
    """
    # Only the prompt tokens are consumed here; the completion-side values
    # are unpacked into underscore-prefixed names and deliberately ignored.
    (
        prompt_tokens,
        _output_tokens_list,
        _output_logprobs_list,
        _finish_reasons,
    ) = await self.tokens_and_logprobs_completion(**kwargs)

    # Fallback path does not have true prompt-logprobs, so we provide
    # interface-compatible singleton values for each prompt token.
    prompt_topk_token_ids = [[token_id] for token_id in prompt_tokens]
    # NOTE(review): 1.0 is not a valid log-probability (real logprobs are
    # <= 0), so this looks like a deliberate "no real value" placeholder --
    # confirm with consumers before treating these as model outputs.
    prompt_topk_logprobs = [[1.0] for _ in prompt_tokens]

    return {
        "prompt_tokens": prompt_tokens,
        "prompt_topk_token_ids": prompt_topk_token_ids,
        "prompt_topk_logprobs": prompt_topk_logprobs,
    }