Mirror of https://github.com/NousResearch/atropos.git
init commit

parent 887a94374c
commit b9291aa29f

5 changed files with 357 additions and 0 deletions
@@ -3,6 +3,7 @@

import asyncio
import warnings
from typing import Any, Dict, List, Tuple

import aiohttp
import openai
@@ -231,6 +232,143 @@ class VLLMServer(APIServer):

            finish_reasons_list,
        )

    @staticmethod
    def _normalize_topk_entry(
        token_logprobs_entry: Any,
    ) -> Tuple[List[int], List[float]]:
        """
        Normalize a single token-position logprob payload into parallel top-k arrays.

        Supports common structures from vLLM responses:
        - dict: {token_id: logprob, ...}
        - list[dict]: [{token_id: logprob}, ...]
        """
        if isinstance(token_logprobs_entry, dict):
            items = list(token_logprobs_entry.items())
            return [int(k) for k, _ in items], [float(v) for _, v in items]

        if isinstance(token_logprobs_entry, list):
            token_ids: List[int] = []
            logprobs: List[float] = []
            for item in token_logprobs_entry:
                if not isinstance(item, dict):
                    continue
                for key, value in item.items():
                    token_ids.append(int(key))
                    logprobs.append(float(value))
            return token_ids, logprobs

        return [], []
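
(A quick sketch of the normalizer's contract; the payloads below are made up to match the two shapes the docstring names, not captured from a live server.)

    # Dict form: one {token_id: logprob} mapping for a single position.
    dict_entry = {151643: -0.01, 198: -4.2}
    # List form: one single-entry dict per candidate; keys may be strings.
    list_entry = [{"151643": -0.01}, {"198": -4.2}]

    assert VLLMServer._normalize_topk_entry(dict_entry) == ([151643, 198], [-0.01, -4.2])
    assert VLLMServer._normalize_topk_entry(list_entry) == ([151643, 198], [-0.01, -4.2])
    # Anything unrecognized degrades to empty parallel arrays.
    assert VLLMServer._normalize_topk_entry(None) == ([], [])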

    async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
        """
        Fetch normalized logprobs from vLLM /generate with optional top-k.

        Args:
            top_k / top_logprobs: Optional number of logprobs per position.
                Defaults to 1.
            prompt or input_ids: Input text or token IDs.

        Returns:
            Normalized dict:
            - prompt_tokens
            - sequence_token_ids
            - sequence_logprobs
            - sequence_topk_token_ids
            - sequence_topk_logprobs
            - finish_reasons
        """
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for get_logprobs!"

        top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
        top_k = max(1, top_k)
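        # Note that the nested pop always consumes "top_logprobs" as well as
        # "top_k", so neither alias leaks into the request payload built below.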

        # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))

        # Check for double BOS token.
        if (
            len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]
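        # A duplicated BOS typically arises when the prompt text already
        # includes the BOS token (e.g. via a chat template) and the tokenizer
        # prepends another one during encode.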
if "max_new_tokens" in kwargs:
|
||||
kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
|
||||
if "max_completion_tokens" in kwargs:
|
||||
kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
|
||||
kwargs.pop("model", None)
|
||||
|
||||
request_data = {
|
||||
"prompt": {"prompt_token_ids": prompt_tokens},
|
||||
"logprobs": top_k,
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
||||
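        # Any remaining kwargs (e.g. max_tokens, temperature) ride along in
        # the request body next to the token-ids prompt and logprobs count.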

        # Keep semaphore behavior consistent with other server calls.
        split = request_data.pop("split", "train")
        sem = self.sem if split == "train" else self.eval_sem
        while not self.server_healthy:
            await asyncio.sleep(1)

        async with sem:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.config.base_url.replace('/v1', '')}/generate",
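                    # base_url normally targets the OpenAI-compatible /v1
                    # routes; vLLM's native /generate endpoint sits at the
                    # server root, hence stripping the suffix.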
                    json=request_data,
                    headers=(
                        {"Authorization": f"Bearer {self.config.api_key}"}
                        if self.config.api_key
                        else {}
                    ),
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                ) as response:
                    response.raise_for_status()
                    results = await response.json()

        sequence_topk_token_ids: List[List[List[int]]] = []
        sequence_topk_logprobs: List[List[List[float]]] = []
        sequence_token_ids: List[List[int]] = []
        sequence_logprobs: List[List[float]] = []
        finish_reasons: List[Any] = []

        for token_logprobs_seq, finish_reason in zip(
            results["logprobs"], results["finish_reasons"]
        ):
            seq_topk_token_ids: List[List[int]] = []
            seq_topk_logprobs: List[List[float]] = []
            seq_token_ids: List[int] = []
            seq_logprobs: List[float] = []

            for token_logprobs_entry in token_logprobs_seq:
                topk_ids, topk_lps = self._normalize_topk_entry(token_logprobs_entry)
                seq_topk_token_ids.append(topk_ids)
                seq_topk_logprobs.append(topk_lps)
                seq_token_ids.append(topk_ids[0] if topk_ids else -1)
                seq_logprobs.append(topk_lps[0] if topk_lps else 0.0)

            sequence_topk_token_ids.append(seq_topk_token_ids)
            sequence_topk_logprobs.append(seq_topk_logprobs)
            sequence_token_ids.append(seq_token_ids)
            sequence_logprobs.append(seq_logprobs)
            finish_reasons.append(finish_reason)
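
        # Positions with unrecognized payloads were kept as sentinel values
        # (token id -1, logprob 0.0) so lengths stay aligned across arrays.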
        return {
            "prompt_tokens": prompt_tokens,
            "sequence_token_ids": sequence_token_ids,
            "sequence_logprobs": sequence_logprobs,
            "sequence_topk_token_ids": sequence_topk_token_ids,
            "sequence_topk_logprobs": sequence_topk_logprobs,
            "finish_reasons": finish_reasons,
        }
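
(A minimal usage sketch for the new method; the prompt and values are hypothetical, and it assumes a VLLMServer instance already pointed at a running vLLM server with a loaded tokenizer.)

    async def inspect_topk(server: "VLLMServer") -> None:
        result = await server.get_logprobs(
            prompt="The capital of France is",  # tokenized internally
            top_k=5,  # keep 5 alternatives per generated position
            max_new_tokens=8,  # normalized onto vLLM's max_tokens
        )
        # Each position of the first returned sequence pairs the chosen
        # token with its top-k alternatives.
        for ids, lps in zip(
            result["sequence_topk_token_ids"][0],
            result["sequence_topk_logprobs"][0],
        ):
            print(list(zip(ids, lps)))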

def resolve_openai_configs(
    default_server_configs,