mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
add managed server to make grabbing logprobs easier w/ tokenized items
This commit is contained in:
parent
312f8859e3
commit
7bf4cfbf80
6 changed files with 1138 additions and 29 deletions
|
|
@ -148,14 +148,24 @@ class SGLangServer(APIServer):
|
|||
kwargs.get("model", None) is not None
|
||||
), "Model is required for completion!"
|
||||
assert (
|
||||
kwargs.get("prompt", None) is not None
|
||||
), "Prompt is required for completion!"
|
||||
kwargs.get("prompt", None) is not None or kwargs.get("input_ids", None) is not None
|
||||
), "Prompt or input_ids is required for completion!"
|
||||
|
||||
# Get n parameter for number of completions
|
||||
n = kwargs.get("n", 1)
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None) # Remove prompt if it exists
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Check for double BOS token, can happen if you use chat templates and forget that they insert a BOS token
|
||||
if prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]:
|
||||
if (
|
||||
len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
|
||||
):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue