mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
prompt logprobs simplicity
This commit is contained in:
parent
f1c20591b6
commit
5aaf7a346c
5 changed files with 69 additions and 69 deletions
|
|
@ -565,34 +565,52 @@ class ManagedServer:
|
|||
prompt, extending_node
|
||||
)
|
||||
|
||||
if hasattr(self.server, "get_logprobs"):
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
else:
|
||||
# Backwards-compatible fallback for harness/test doubles.
|
||||
(
|
||||
prompt_tokens,
|
||||
_output_tokens_list,
|
||||
_output_logprobs_list,
|
||||
_finish_reasons,
|
||||
) = await self.server.tokens_and_logprobs_completion(**request_kwargs)
|
||||
payload = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
|
||||
"prompt_topk_logprobs": [[1.0] for _ in prompt_tokens],
|
||||
}
|
||||
if not hasattr(self.server, "get_logprobs"):
|
||||
raise NotImplementedError(
|
||||
f"{self.server.__class__.__name__} does not implement get_logprobs. "
|
||||
"Strict mode requires backend prompt logprobs."
|
||||
)
|
||||
|
||||
# Normalize required keys if provider omitted prompt top-k arrays.
|
||||
if "prompt_topk_token_ids" not in payload:
|
||||
payload["prompt_topk_token_ids"] = [
|
||||
[tok] for tok in payload.get("prompt_tokens", [])
|
||||
]
|
||||
if "prompt_topk_logprobs" not in payload:
|
||||
payload["prompt_topk_logprobs"] = [
|
||||
[1.0] for _ in payload.get("prompt_tokens", [])
|
||||
]
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
self._validate_prompt_logprob_payload(payload)
|
||||
|
||||
return payload
|
||||
|
||||
@staticmethod
def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None:
    """
    Check that a ``get_logprobs`` response has a well-formed prompt top-k layout.

    The payload must be a mapping containing ``prompt_tokens``,
    ``prompt_topk_token_ids`` and ``prompt_topk_logprobs``. ``prompt_tokens``
    must be a list; the two top-k entries must be lists of lists with one row
    per prompt token, and each token-id row must have the same length as the
    logprob row at the same position. Element types inside the rows are not
    inspected.

    Args:
        payload: Response object returned by the backend's ``get_logprobs``.

    Raises:
        ValueError: If the payload is not a mapping, a required key is
            missing, a container has the wrong type, or lengths do not align.
    """
    from collections.abc import Mapping

    # Without this guard a non-mapping response (e.g. None) would surface as
    # a TypeError from the membership test below rather than the ValueError
    # used for every other malformed-payload case.
    if not isinstance(payload, Mapping):
        raise ValueError(
            "get_logprobs response must be a mapping, "
            f"got {type(payload).__name__}."
        )

    required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
    missing = [k for k in required if k not in payload]
    if missing:
        raise ValueError(
            f"get_logprobs response missing required keys: {missing}"
        )

    prompt_tokens = payload["prompt_tokens"]
    token_ids = payload["prompt_topk_token_ids"]
    logprobs = payload["prompt_topk_logprobs"]

    if not isinstance(prompt_tokens, list):
        raise ValueError("prompt_tokens must be a list[int].")
    if not isinstance(token_ids, list) or not isinstance(logprobs, list):
        raise ValueError(
            "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
        )
    # Both top-k arrays carry exactly one row per prompt position.
    if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens):
        raise ValueError(
            "prompt_topk arrays must align with prompt_tokens length."
        )

    for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)):
        if not isinstance(tok_row, list) or not isinstance(lp_row, list):
            raise ValueError(
                "prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
            )
        # Every candidate token id needs a matching logprob at this position.
        if len(tok_row) != len(lp_row):
            raise ValueError(
                f"prompt_topk row mismatch at position {idx}: "
                f"{len(tok_row)} token ids vs {len(lp_row)} logprobs."
            )
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
@ -706,17 +724,12 @@ class DummyManagedServer:
|
|||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Dummy managed server does not provide real prompt logprobs.

    The visible span merged a placeholder-returning body with a trailing
    ``raise`` that could never execute. This resolves it to the strict
    behavior: the call always fails instead of handing back placeholder
    values.

    Args:
        **kwargs: Accepted for signature parity with other servers; ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "DummyManagedServer does not support get_logprobs in strict mode. "
        "Use a backend with real prompt logprob support."
    )
|
||||
|
||||
|
||||
class ManagedServerAdapter:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue