mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fixing comments
This commit is contained in:
parent
51088ac24d
commit
1eeb31065f
6 changed files with 21 additions and 81 deletions
|
|
@ -553,18 +553,6 @@ class ManagedServer:
|
|||
else:
|
||||
prompt = request_kwargs.get("prompt")
|
||||
|
||||
# Reuse tracked context in non-tree mode when possible.
|
||||
if (
|
||||
not self.track_tree
|
||||
and self.tokenizer is not None
|
||||
and "input_ids" not in request_kwargs
|
||||
and prompt is not None
|
||||
):
|
||||
extending_node = self._find_extending_node(prompt)
|
||||
request_kwargs["input_ids"] = self._compute_input_ids(
|
||||
prompt, extending_node
|
||||
)
|
||||
|
||||
if not hasattr(self.server, "get_logprobs"):
|
||||
raise NotImplementedError(
|
||||
f"{self.server.__class__.__name__} does not implement get_logprobs. "
|
||||
|
|
@ -572,41 +560,8 @@ class ManagedServer:
|
|||
)
|
||||
|
||||
payload = await self.server.get_logprobs(**request_kwargs)
|
||||
self._validate_prompt_logprob_payload(payload)
|
||||
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None:
|
||||
required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
|
||||
missing = [k for k in required if k not in payload]
|
||||
if missing:
|
||||
raise ValueError(f"get_logprobs response missing required keys: {missing}")
|
||||
|
||||
prompt_tokens = payload["prompt_tokens"]
|
||||
token_ids = payload["prompt_topk_token_ids"]
|
||||
logprobs = payload["prompt_topk_logprobs"]
|
||||
|
||||
if not isinstance(prompt_tokens, list):
|
||||
raise ValueError("prompt_tokens must be a list[int].")
|
||||
if not isinstance(token_ids, list) or not isinstance(logprobs, list):
|
||||
raise ValueError(
|
||||
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
|
||||
)
|
||||
if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens):
|
||||
raise ValueError("prompt_topk arrays must align with prompt_tokens length.")
|
||||
|
||||
for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)):
|
||||
if not isinstance(tok_row, list) or not isinstance(lp_row, list):
|
||||
raise ValueError(
|
||||
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
|
||||
)
|
||||
if len(tok_row) != len(lp_row):
|
||||
raise ValueError(
|
||||
f"prompt_topk row mismatch at position {idx}: "
|
||||
f"{len(tok_row)} token ids vs {len(lp_row)} logprobs."
|
||||
)
|
||||
|
||||
|
||||
class DummyManagedServer:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue