mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
prompt logprobs
This commit is contained in:
parent
e98100e5f6
commit
439b9b129b
7 changed files with 73 additions and 138 deletions
|
|
@@ -348,19 +348,16 @@ class ServerManager:
|
|||
|
||||
async def get_logprobs(self, **kwargs) -> dict:
|
||||
"""
|
||||
Route normalized get_logprobs requests to the most available server.
|
||||
Route normalized prompt-logprob requests to the most available server.
|
||||
|
||||
Returns a normalized dict with:
|
||||
- prompt_tokens
|
||||
- sequence_token_ids
|
||||
- sequence_logprobs
|
||||
- sequence_topk_token_ids
|
||||
- sequence_topk_logprobs
|
||||
- finish_reasons
|
||||
- prompt_topk_token_ids
|
||||
- prompt_topk_logprobs
|
||||
"""
|
||||
n = kwargs.get("n", 1)
|
||||
if n > self.max_n_completions:
|
||||
# Split into multiple requests and merge sequence-level outputs.
|
||||
# Prompt logprobs are prompt-level; n-splitting does not change prompt arrays.
|
||||
results = []
|
||||
total_n = n
|
||||
while total_n > 0:
|
||||
|
|
@@ -369,25 +366,7 @@ class ServerManager:
|
|||
results.append(self.get_logprobs(**kwargs))
|
||||
total_n -= n_to_use
|
||||
results = await asyncio.gather(*results)
|
||||
merged = {
|
||||
"prompt_tokens": results[0]["prompt_tokens"],
|
||||
"sequence_token_ids": [],
|
||||
"sequence_logprobs": [],
|
||||
"sequence_topk_token_ids": [],
|
||||
"sequence_topk_logprobs": [],
|
||||
"finish_reasons": [],
|
||||
}
|
||||
for result in results:
|
||||
merged["sequence_token_ids"].extend(result["sequence_token_ids"])
|
||||
merged["sequence_logprobs"].extend(result["sequence_logprobs"])
|
||||
merged["sequence_topk_token_ids"].extend(
|
||||
result["sequence_topk_token_ids"]
|
||||
)
|
||||
merged["sequence_topk_logprobs"].extend(
|
||||
result["sequence_topk_logprobs"]
|
||||
)
|
||||
merged["finish_reasons"].extend(result["finish_reasons"])
|
||||
return merged
|
||||
return results[0]
|
||||
|
||||
is_train = kwargs.pop("split", "train") == "train"
|
||||
most_available_server = 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue