prompt logprobs

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 21:56:11 -05:00
parent e98100e5f6
commit 439b9b129b
7 changed files with 73 additions and 138 deletions

View file

@@ -531,7 +531,7 @@ class ManagedServer:
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Fetch logprobs via wrapped server with a normalized trainer-agnostic schema.
Fetch prompt logprobs via wrapped server with a normalized schema.
Supported inputs:
- prompt
@@ -541,11 +541,8 @@
Returns:
Dict with:
- prompt_tokens
- sequence_token_ids
- sequence_logprobs
- sequence_topk_token_ids
- sequence_topk_logprobs
- finish_reasons
- prompt_topk_token_ids
- prompt_topk_logprobs
"""
request_kwargs = kwargs.copy()
messages = request_kwargs.pop("messages", None)
@@ -574,31 +571,24 @@
# Backwards-compatible fallback for harness/test doubles.
(
prompt_tokens,
output_tokens_list,
output_logprobs_list,
finish_reasons,
_output_tokens_list,
_output_logprobs_list,
_finish_reasons,
) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
payload = {
"prompt_tokens": prompt_tokens,
"sequence_token_ids": output_tokens_list,
"sequence_logprobs": output_logprobs_list,
"sequence_topk_token_ids": [
[[tok] for tok in seq] for seq in output_tokens_list
],
"sequence_topk_logprobs": [
[[lp] for lp in seq] for seq in output_logprobs_list
],
"finish_reasons": finish_reasons,
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
"prompt_topk_logprobs": [[1.0] for _ in prompt_tokens],
}
# Normalize required keys if provider omitted top-k arrays.
if "sequence_topk_token_ids" not in payload:
payload["sequence_topk_token_ids"] = [
[[tok] for tok in seq] for seq in payload["sequence_token_ids"]
# Normalize required keys if provider omitted prompt top-k arrays.
if "prompt_topk_token_ids" not in payload:
payload["prompt_topk_token_ids"] = [
[tok] for tok in payload.get("prompt_tokens", [])
]
if "sequence_topk_logprobs" not in payload:
payload["sequence_topk_logprobs"] = [
[[lp] for lp in seq] for seq in payload["sequence_logprobs"]
if "prompt_topk_logprobs" not in payload:
payload["prompt_topk_logprobs"] = [
[1.0] for _ in payload.get("prompt_tokens", [])
]
return payload
@@ -722,15 +712,12 @@ class DummyManagedServer:
that results are placeholders and not suitable for training.
"""
n = int(kwargs.get("n", 1))
seq_ids = [self.DUMMY_TOKENS[:] for _ in range(n)]
seq_lps = [self.DUMMY_LOGPROBS[:] for _ in range(n)]
prompt_tokens = self.DUMMY_TOKENS[:]
return {
"prompt_tokens": [],
"sequence_token_ids": seq_ids,
"sequence_logprobs": seq_lps,
"sequence_topk_token_ids": [[[tok] for tok in seq] for seq in seq_ids],
"sequence_topk_logprobs": [[[lp] for lp in seq] for seq in seq_lps],
"finish_reasons": ["stop"] * n,
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
"prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens],
"finish_reasons": ["stop"] * n, # Retained for compatibility in callers.
}