prompt logprobs simplicity

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 22:06:49 -05:00
parent f1c20591b6
commit 5aaf7a346c
5 changed files with 69 additions and 69 deletions

View file

@ -565,34 +565,52 @@ class ManagedServer:
prompt, extending_node
)
if hasattr(self.server, "get_logprobs"):
payload = await self.server.get_logprobs(**request_kwargs)
else:
# Backwards-compatible fallback for harness/test doubles.
(
prompt_tokens,
_output_tokens_list,
_output_logprobs_list,
_finish_reasons,
) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
payload = {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
"prompt_topk_logprobs": [[1.0] for _ in prompt_tokens],
}
if not hasattr(self.server, "get_logprobs"):
raise NotImplementedError(
f"{self.server.__class__.__name__} does not implement get_logprobs. "
"Strict mode requires backend prompt logprobs."
)
# Normalize required keys if provider omitted prompt top-k arrays.
if "prompt_topk_token_ids" not in payload:
payload["prompt_topk_token_ids"] = [
[tok] for tok in payload.get("prompt_tokens", [])
]
if "prompt_topk_logprobs" not in payload:
payload["prompt_topk_logprobs"] = [
[1.0] for _ in payload.get("prompt_tokens", [])
]
payload = await self.server.get_logprobs(**request_kwargs)
self._validate_prompt_logprob_payload(payload)
return payload
@staticmethod
def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None:
required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
missing = [k for k in required if k not in payload]
if missing:
raise ValueError(
f"get_logprobs response missing required keys: {missing}"
)
prompt_tokens = payload["prompt_tokens"]
token_ids = payload["prompt_topk_token_ids"]
logprobs = payload["prompt_topk_logprobs"]
if not isinstance(prompt_tokens, list):
raise ValueError("prompt_tokens must be a list[int].")
if not isinstance(token_ids, list) or not isinstance(logprobs, list):
raise ValueError(
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
)
if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens):
raise ValueError(
"prompt_topk arrays must align with prompt_tokens length."
)
for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)):
if not isinstance(tok_row, list) or not isinstance(lp_row, list):
raise ValueError(
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
)
if len(tok_row) != len(lp_row):
raise ValueError(
f"prompt_topk row mismatch at position {idx}: "
f"{len(tok_row)} token ids vs {len(lp_row)} logprobs."
)
class DummyManagedServer:
"""
@ -706,17 +724,12 @@ class DummyManagedServer:
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
"""
Return interface-compatible dummy logprob payload.
This keeps interface parity with ManagedServer while making it explicit
that results are placeholders and not suitable for training.
Dummy managed server does not provide real prompt logprobs.
"""
prompt_tokens = self.DUMMY_TOKENS[:]
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
"prompt_topk_logprobs": [[self.DUMMY_LOGPROBS[0]] for _ in prompt_tokens],
}
raise NotImplementedError(
"DummyManagedServer does not support get_logprobs in strict mode. "
"Use a backend with real prompt logprob support."
)
class ManagedServerAdapter: