This commit is contained in:
Jai Suphavadeeprasit 2026-03-09 21:25:58 -04:00
parent 4f33ab8bf4
commit bb2736db4e
2 changed files with 44 additions and 0 deletions

View file

@@ -2,6 +2,7 @@
# see example_trainer/vllm_api_server.py for an example
import asyncio
import logging
import warnings
from typing import Any, Dict, List, Tuple
@@ -19,6 +20,8 @@ from atroposlib.envs.server_handling.server_baseline import (
ReasoningConfig,
)
logger = logging.getLogger(__name__)
class VLLMServer(APIServer):
"""
@@ -190,6 +193,14 @@ class VLLMServer(APIServer):
# Prepare request for VLLM native API
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
request_data.update(kwargs)
logger.info(
"vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
self.config.base_url,
len(prompt_tokens),
request_data.get("n"),
request_data.get("max_tokens"),
request_data.get("temperature"),
)
# Make async request to VLLM /generate endpoint
async with aiohttp.ClientSession() as session:
@@ -205,6 +216,11 @@ class VLLMServer(APIServer):
) as response:
response.raise_for_status()
results = await response.json()
logger.info(
"vllm_server completion POST done outputs=%s finish_reasons=%s",
len(results.get("logprobs", [])),
len(results.get("finish_reasons", [])),
)
output_tokens_list = []
output_logprobs_list = []
finish_reasons_list = []
@@ -314,6 +330,13 @@ class VLLMServer(APIServer):
request_data["temperature"] = 0.0
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
logger.info(
"vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
self.config.base_url,
len(prompt_tokens),
top_k,
request_data.get("max_tokens"),
)
async with aiohttp.ClientSession() as session:
async with session.post(
@@ -328,6 +351,10 @@ class VLLMServer(APIServer):
) as response:
response.raise_for_status()
results = await response.json()
logger.info(
"vllm_server get_logprobs POST done prompt_logprobs_present=%s",
results.get("prompt_logprobs") is not None,
)
raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None: