fixing comments

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 23:08:28 -05:00
parent 51088ac24d
commit 1eeb31065f
6 changed files with 21 additions and 81 deletions

View file

@ -553,18 +553,6 @@ class ManagedServer:
else:
prompt = request_kwargs.get("prompt")
# Reuse tracked context in non-tree mode when possible.
if (
not self.track_tree
and self.tokenizer is not None
and "input_ids" not in request_kwargs
and prompt is not None
):
extending_node = self._find_extending_node(prompt)
request_kwargs["input_ids"] = self._compute_input_ids(
prompt, extending_node
)
if not hasattr(self.server, "get_logprobs"):
raise NotImplementedError(
f"{self.server.__class__.__name__} does not implement get_logprobs. "
@ -572,41 +560,8 @@ class ManagedServer:
)
payload = await self.server.get_logprobs(**request_kwargs)
self._validate_prompt_logprob_payload(payload)
return payload
@staticmethod
def _validate_prompt_logprob_payload(payload: Dict[str, Any]) -> None:
required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
missing = [k for k in required if k not in payload]
if missing:
raise ValueError(f"get_logprobs response missing required keys: {missing}")
prompt_tokens = payload["prompt_tokens"]
token_ids = payload["prompt_topk_token_ids"]
logprobs = payload["prompt_topk_logprobs"]
if not isinstance(prompt_tokens, list):
raise ValueError("prompt_tokens must be a list[int].")
if not isinstance(token_ids, list) or not isinstance(logprobs, list):
raise ValueError(
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
)
if len(token_ids) != len(prompt_tokens) or len(logprobs) != len(prompt_tokens):
raise ValueError("prompt_topk arrays must align with prompt_tokens length.")
for idx, (tok_row, lp_row) in enumerate(zip(token_ids, logprobs)):
if not isinstance(tok_row, list) or not isinstance(lp_row, list):
raise ValueError(
"prompt_topk_token_ids and prompt_topk_logprobs must be list-of-list."
)
if len(tok_row) != len(lp_row):
raise ValueError(
f"prompt_topk row mismatch at position {idx}: "
f"{len(tok_row)} token ids vs {len(lp_row)} logprobs."
)
class DummyManagedServer:
"""