mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
next
This commit is contained in:
parent
4f33ab8bf4
commit
bb2736db4e
2 changed files with 44 additions and 0 deletions
|
|
@@ -2,6 +2,7 @@
|
|||
# see example_trainer/vllm_api_server.py for an example
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
|
@@ -19,6 +20,8 @@ from atroposlib.envs.server_handling.server_baseline import (
|
|||
ReasoningConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VLLMServer(APIServer):
|
||||
"""
|
||||
|
|
@@ -190,6 +193,14 @@ class VLLMServer(APIServer):
|
|||
# Prepare request for VLLM native API
|
||||
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
|
||||
request_data.update(kwargs)
|
||||
logger.info(
|
||||
"vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
|
||||
self.config.base_url,
|
||||
len(prompt_tokens),
|
||||
request_data.get("n"),
|
||||
request_data.get("max_tokens"),
|
||||
request_data.get("temperature"),
|
||||
)
|
||||
|
||||
# Make async request to VLLM /generate endpoint
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
|
@@ -205,6 +216,11 @@ class VLLMServer(APIServer):
|
|||
) as response:
|
||||
response.raise_for_status()
|
||||
results = await response.json()
|
||||
logger.info(
|
||||
"vllm_server completion POST done outputs=%s finish_reasons=%s",
|
||||
len(results.get("logprobs", [])),
|
||||
len(results.get("finish_reasons", [])),
|
||||
)
|
||||
output_tokens_list = []
|
||||
output_logprobs_list = []
|
||||
finish_reasons_list = []
|
||||
|
|
@@ -314,6 +330,13 @@ class VLLMServer(APIServer):
|
|||
request_data["temperature"] = 0.0
|
||||
request_data["top_p"] = 1.0
|
||||
request_data.setdefault("max_tokens", 1)
|
||||
logger.info(
|
||||
"vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
|
||||
self.config.base_url,
|
||||
len(prompt_tokens),
|
||||
top_k,
|
||||
request_data.get("max_tokens"),
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
|
@@ -328,6 +351,10 @@ class VLLMServer(APIServer):
|
|||
) as response:
|
||||
response.raise_for_status()
|
||||
results = await response.json()
|
||||
logger.info(
|
||||
"vllm_server get_logprobs POST done prompt_logprobs_present=%s",
|
||||
results.get("prompt_logprobs") is not None,
|
||||
)
|
||||
|
||||
raw_prompt_logprobs = results.get("prompt_logprobs")
|
||||
if raw_prompt_logprobs is None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue