mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00

commit bb2736db4e ("next")
parent 4f33ab8bf4
2 changed files with 44 additions and 0 deletions
@@ -447,14 +447,31 @@ class ManagedServer:
         if not self.track_tree and self.tokenizer is not None:
             input_ids = self._compute_input_ids(prompt, extending_node)
             completion_kwargs["input_ids"] = input_ids
+            logger.info(
+                "managed_server chat_completion prepared input_ids=%s extending=%s",
+                len(input_ids),
+                extending_node is not None,
+            )
+        else:
+            logger.info(
+                "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
+                self.track_tree,
+                self.tokenizer is not None,
+            )
 
         # Call the tokens and logprobs wrapper directly
+        logger.info("managed_server chat_completion calling backend completion wrapper")
         (
             prompt_tokens,
             output_tokens_list,
             output_logprobs_list,
             finish_reasons,
         ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
+        logger.info(
+            "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
+            len(prompt_tokens),
+            len(output_tokens_list),
+        )
 
         # Track each completion and build choices
         n = len(output_tokens_list)
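The new log calls use logging's lazy %s formatting, so the arguments are only rendered when a record is actually emitted at INFO. A minimal, self-contained sketch of the same pattern (the logger name, basicConfig setup, and sample values are illustrative, not taken from the diff):

```python
import logging

logging.basicConfig(level=logging.INFO)  # assumed consumer-side setup
logger = logging.getLogger("managed_server_demo")  # illustrative name

input_ids = [101, 7592, 2088, 102]  # placeholder token ids
extending_node = None

# With %s-style args, formatting is deferred until the record is emitted,
# so disabled log levels cost almost nothing.
logger.info(
    "managed_server chat_completion prepared input_ids=%s extending=%s",
    len(input_ids),
    extending_node is not None,
)
```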
@@ -2,6 +2,7 @@
 # see example_trainer/vllm_api_server.py for an example
 
 import asyncio
+import logging
 import warnings
 from typing import Any, Dict, List, Tuple
 
@@ -19,6 +20,8 @@ from atroposlib.envs.server_handling.server_baseline import (
     ReasoningConfig,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class VLLMServer(APIServer):
     """
@@ -190,6 +193,14 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
+        logger.info(
+            "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
+            self.config.base_url,
+            len(prompt_tokens),
+            request_data.get("n"),
+            request_data.get("max_tokens"),
+            request_data.get("temperature"),
+        )
 
         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:
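The log added here fires just before the POST to the vLLM native endpoint and records the effective sampling parameters. A standalone sketch of how request_data is assembled and why .get() is used in the log call (the token ids and kwargs are placeholders, not values from the repo):

```python
prompt_tokens = [1, 2, 3, 4]  # placeholder token ids
kwargs = {"n": 2, "max_tokens": 128, "temperature": 0.7}  # caller-supplied sampling args

# Same shape as the diff: a token-id prompt plus logprobs=0 to ask for
# logprobs of the sampled tokens.
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
request_data.update(kwargs)

# .get() keeps the log line safe when a caller omits a key: it logs None
# instead of raising KeyError.
print(request_data.get("n"), request_data.get("max_tokens"), request_data.get("temperature"))
```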
@@ -205,6 +216,11 @@
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
+                logger.info(
+                    "vllm_server completion POST done outputs=%s finish_reasons=%s",
+                    len(results.get("logprobs", [])),
+                    len(results.get("finish_reasons", [])),
+                )
                 output_tokens_list = []
                 output_logprobs_list = []
                 finish_reasons_list = []
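On the response side, the log counts entries rather than dumping the body, and results.get(..., []) means a missing key logs as 0 instead of raising. A sketch with a stand-in payload (the payload below is fabricated for illustration; the real one comes from the vLLM server):

```python
# Stand-in for `results = await response.json()`; only the keys the log
# reads are mirrored here.
results = {
    "logprobs": [[-0.11, -0.42], [-0.08]],
    "finish_reasons": ["stop", "length"],
}

# Defaulting to [] keeps len() valid even if the server omits a key.
outputs = len(results.get("logprobs", []))
finish_reasons = len(results.get("finish_reasons", []))
print(outputs, finish_reasons)  # 2 2
```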
@@ -314,6 +330,13 @@
         request_data["temperature"] = 0.0
         request_data["top_p"] = 1.0
         request_data.setdefault("max_tokens", 1)
+        logger.info(
+            "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
+            self.config.base_url,
+            len(prompt_tokens),
+            top_k,
+            request_data.get("max_tokens"),
+        )
 
         async with aiohttp.ClientSession() as session:
             async with session.post(
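get_logprobs pins temperature and top_p for a deterministic scoring pass but only setdefaults max_tokens, so a caller-supplied value survives. A small sketch of that distinction (the pre-set value is illustrative):

```python
request_data = {"max_tokens": 8}  # pretend the caller already chose this

request_data["temperature"] = 0.0         # always overwritten: deterministic scoring
request_data["top_p"] = 1.0               # always overwritten
request_data.setdefault("max_tokens", 1)  # fills in only when the key is absent

print(request_data)  # {'max_tokens': 8, 'temperature': 0.0, 'top_p': 1.0}
```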
@@ -328,6 +351,10 @@
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
+                logger.info(
+                    "vllm_server get_logprobs POST done prompt_logprobs_present=%s",
+                    results.get("prompt_logprobs") is not None,
+                )
 
                 raw_prompt_logprobs = results.get("prompt_logprobs")
                 if raw_prompt_logprobs is None:
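The final hunk only logs whether prompt_logprobs came back at all; the pre-existing code then branches on None. A sketch of that guard in isolation (the fallback return is illustrative, since the diff is truncated before the real branch body):

```python
from typing import Any, Dict, List, Optional


def extract_prompt_logprobs(results: Dict[str, Any]) -> Optional[List[Any]]:
    # Mirrors the diff: a null/missing key is detected up front rather than
    # being passed downstream as data.
    raw_prompt_logprobs = results.get("prompt_logprobs")
    if raw_prompt_logprobs is None:
        return None  # illustrative fallback; real handling is cut off in the diff
    return raw_prompt_logprobs


print(extract_prompt_logprobs({"prompt_logprobs": None}))            # None
print(extract_prompt_logprobs({"prompt_logprobs": [{"42": -0.3}]}))  # [{'42': -0.3}]
```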