add managed server to make grabbing logprobs easier w/ tokenized items

This commit is contained in:
dmahan93 2025-10-24 13:09:46 -07:00
parent 312f8859e3
commit 7bf4cfbf80
6 changed files with 1138 additions and 29 deletions

View file

@ -148,14 +148,24 @@ class SGLangServer(APIServer):
kwargs.get("model", None) is not None
), "Model is required for completion!"
assert (
kwargs.get("prompt", None) is not None
), "Prompt is required for completion!"
kwargs.get("prompt", None) is not None or kwargs.get("input_ids", None) is not None
), "Prompt or input_ids is required for completion!"
# Get n parameter for number of completions
n = kwargs.get("n", 1)
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
if "input_ids" in kwargs:
prompt_tokens = kwargs.pop("input_ids")
kwargs.pop("prompt", None) # Remove prompt if it exists
else:
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
# Check for double BOS token, can happen if you use chat templates and forget that they insert a BOS token
if prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]:
if (
len(prompt_tokens) >= 2
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
):
prompt_tokens = prompt_tokens[1:]
if "max_tokens" in kwargs:
kwargs["max_new_tokens"] = kwargs.pop("max_tokens")