add managed vllm server

This commit is contained in:
Dakota 2025-11-07 13:06:49 -06:00
parent 578175a709
commit e6ac3abdcb
9 changed files with 597 additions and 15 deletions

View file

@ -171,6 +171,8 @@ pre-commit install
You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running vLLM or SGLang inference server as well as any other [configuration changes](CONFIG.md) you'd like to make, such as the group size, then:
> **Note:** By default, Atropos uses the OpenAI-compatible API endpoint which works with any provider. For enhanced features, use `VLLMServer` (atroposlib/envs/server_handling/vllm_server.py) or `SGLangServer` (atroposlib/envs/server_handling/sglang_server.py) for direct access to native APIs with full token and logprob tracking.
```bash
# Start the API server
run-api

View file

@ -181,6 +181,7 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c
* **`server_cls: Type[APIServer]`**:
* Default: `APIServer`
* Purpose: Specifies the class to be used for managing interactions with API servers (e.g., inference endpoints). Should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting with an existing API you can use this to easily slot in any modifications you need.
* **Note:** In most cases, you should use the `server_type` field in your `APIServerConfig` instead of overriding this. Set `server_type` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to automatically use the appropriate server class with enhanced features like native API access and full token/logprob tracking.
## Provided Functionality

View file

@ -4,6 +4,8 @@
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.
**Server Compatibility:** ManagedServer works with all Atropos server types - `OpenAIServer`, `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Simply set the `server_type` field in your `APIServerConfig` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to use the appropriate backend with automatic server class selection.
### Why Use ManagedServer?
**Before ManagedServer** (manual approach):

View file

@ -108,8 +108,8 @@ class ServerBaseline(BaseModel):
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
server_type: Literal["openai", "trl", "sglang"] = Field(
default="openai", description="Type of server to use, openai or trl"
server_type: Literal["openai", "trl", "sglang", "vllm"] = Field(
default="openai", description="Type of server to use"
)

View file

@ -18,6 +18,7 @@ from atroposlib.envs.server_handling.server_baseline import (
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.vllm_server import VLLMServer
class ServerManagerConfig(BaseModel):
@ -58,6 +59,8 @@ class ServerManager:
server_class = TrlVllmServer
elif configs.server_type == "sglang":
server_class = SGLangServer
elif configs.server_type == "vllm":
server_class = VLLMServer
else:
raise ValueError(f"Invalid server type: {configs.server_type}")
else:
@ -67,6 +70,8 @@ class ServerManager:
server_class = TrlVllmServer
elif configs[0].server_type == "sglang":
server_class = SGLangServer
elif configs[0].server_type == "vllm":
server_class = VLLMServer
else:
raise ValueError(f"Invalid server type: {configs[0].server_type}")
if testing:
@ -198,7 +203,7 @@ class ServerManager:
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
is_train = kwargs.get("split", "train") == "train"
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
@ -231,7 +236,7 @@ class ServerManager:
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
is_train = kwargs.get("split", "train") == "train"
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
@ -276,7 +281,7 @@ class ServerManager:
finish_reasons.extend(out_finish_reasons)
return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
is_train = kwargs.get("split", "train") == "train"
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)

View file

@ -0,0 +1,344 @@
# This requires a customized vLLM api server
# see example_trainer/vllm_api_server.py for an example
import asyncio
import warnings
import aiohttp
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic_cli import FailedExecutionException
from transformers import AutoTokenizer
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
class VLLMServer(APIServer):
    """
    VLLM server handling.

    Talks to a vLLM deployment two ways:
      * standard chat/completions through the OpenAI-compatible client;
      * token/logprob-level completions through the customized native
        ``/generate`` endpoint (see example_trainer/vllm_api_server.py).
    """

    def __init__(self, config: APIServerConfig):
        # Async OpenAI client for the OpenAI-compatible endpoints.
        self.openai = openai.AsyncClient(
            api_key=config.api_key,
            base_url=config.base_url,
            timeout=config.timeout,
        )
        # Local tokenizer so prompts can be sent as token ids to /generate.
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        super().__init__(config)

    async def check_server_status_task(self, chat_completion: bool = True):
        """
        Background task: poll the server's ``/health`` endpoint once per
        second and keep ``self.server_healthy`` up to date.

        ``chat_completion`` is accepted for interface compatibility but is
        unused here, since vLLM exposes a dedicated health endpoint.
        """
        while True:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(
                        f"{self.config.base_url.replace('/v1', '')}/health",
                        headers=(
                            {"Authorization": f"Bearer {self.config.api_key}"}
                            if self.config.api_key
                            else {}
                        ),
                        timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                    ) as response:
                        response.raise_for_status()
                        self.server_healthy = True
            except Exception:
                # Deliberately broad: any failure (connect error, timeout,
                # non-2xx status) marks the server unhealthy and the watchdog
                # loop keeps running. The original tuple listed aiohttp and
                # openai errors alongside Exception, which subsumes them.
                # asyncio.CancelledError is a BaseException, so task
                # cancellation still propagates normally.
                self.server_healthy = False
            await asyncio.sleep(1)

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the openai client.

        If the backing API silently ignores the ``n`` kwarg (detected when a
        request for ``n`` choices returns exactly one), remember that on the
        config and fan out one request per requested choice, merging all
        choices into a single ChatCompletion.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for chat completion!"
        assert (
            kwargs.get("messages", None) is not None
        ), "Messages are required for chat completion!"
        if self.config.n_kwarg_is_ignored:
            # Known-ignored: fan out n independent requests and merge.
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            # Bugfix: the original issued a redundant extra request when
            # n == 1 (an else-branch re-called create()); the gathered
            # result already covers that case.
            for c in completion_list[1:]:
                completions.choices.extend(c.choices)
        else:
            n = kwargs.get("n", 1)
            completions = await self.openai.chat.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                # Exactly one choice came back for n > 1: the API ignores
                # ``n``. Record that and top up the remaining n - 1 choices.
                warnings.warn("n kwarg is ignored by the API, setting to True")
                self.config.n_kwarg_is_ignored = True
                completion_list = await asyncio.gather(
                    *[
                        self.openai.chat.completions.create(**kwargs)
                        for _ in range(1, n)
                    ]
                )
                for c in completion_list:
                    completions.choices.extend(c.choices)
        return completions

    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the openai client.

        Mirrors ``_chat_completion_wrapper``: handles APIs that ignore the
        ``n`` kwarg by fanning out one request per requested choice.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
        ), "Prompt is required for completion!"
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.completions.create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            for c in completion_list[1:]:
                completions.choices.extend(c.choices)
        else:
            n = kwargs.get("n", 1)
            completions = await self.openai.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                warnings.warn("n kwarg is ignored by the API, setting to True")
                self.config.n_kwarg_is_ignored = True
                completion_list = await asyncio.gather(
                    *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
                )
                for c in completion_list:
                    completions.choices.extend(c.choices)
        return completions

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for tokens and logprobs completion using VLLM's native API.

        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs,
        finish_reasons). All elements except prompt_tokens are lists of lists
        (one per completion in the batch).
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for completion!"
        # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)  # Remove prompt if it exists
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
        # Check for double BOS token, can happen if you use chat templates
        # and forget that they insert a BOS token
        if (
            len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]
        # Translate the HF-style kwarg to vLLM's SamplingParams name.
        if "max_new_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
        # The native endpoint has no model argument — the server hosts one model.
        kwargs.pop("model", None)
        # Prepare request for VLLM native API. logprobs=0 requests only the
        # sampled token's logprob; callers may override via kwargs.
        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
        request_data.update(kwargs)
        # Make async request to VLLM /generate endpoint
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.config.base_url.replace('/v1', '')}/generate",
                json=request_data,
                headers=(
                    {"Authorization": f"Bearer {self.config.api_key}"}
                    if self.config.api_key
                    else {}
                ),
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
                response.raise_for_status()
                results = await response.json()
                output_tokens_list = []
                output_logprobs_list = []
                finish_reasons_list = []
                # Each entry of results["logprobs"] is a per-token list of
                # single-entry [{token_id: logprob}] items (see the custom
                # api server's response format).
                for output_token_logprobs, finish_reason in zip(
                    results["logprobs"], results["finish_reasons"]
                ):
                    logprobs = [
                        list(item[0].values())[0] for item in output_token_logprobs
                    ]  # Extract logprob from [{id: logprob}]
                    output_ids = [
                        int(list(item[0].keys())[0]) for item in output_token_logprobs
                    ]  # Extract token ID from [{id: logprob}]
                    output_tokens_list.append(output_ids)
                    output_logprobs_list.append(logprobs)
                    finish_reasons_list.append(finish_reason)
                return (
                    prompt_tokens,
                    output_tokens_list,
                    output_logprobs_list,
                    finish_reasons_list,
                )
def resolve_openai_configs(
    default_server_configs,
    openai_config_dict,
    yaml_config,
    cli_passed_flags,
    logger,
):
    """
    Helper to resolve the final server_configs, handling single, multiple servers, and overrides.
    """
    from atroposlib.envs.server_handling.server_manager import ServerBaseline

    prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
    yaml_openai = yaml_config.get(OPENAI_NAMESPACE, None)
    cli_openai = {
        key: val for key, val in cli_passed_flags.items() if key.startswith(prefix)
    }
    multi_yaml = isinstance(yaml_openai, list) and len(yaml_openai) >= 2
    multi_default = (
        not multi_yaml
        and isinstance(default_server_configs, list)
        and len(default_server_configs) >= 2
    )
    # CLI overrides cannot be applied unambiguously across several servers.
    if cli_openai and (multi_yaml or multi_default):
        raise FailedExecutionException(
            message=f"CLI overrides for OpenAI settings (--{prefix}*) are not supported "
            f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' "
            "or a default list with length >= 2).",
            exit_code=2,
        )
    # Highest priority: an explicit multi-server list in the YAML config.
    if multi_yaml:
        logger.info(
            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
        )
        try:
            return [APIServerConfig(**cfg) for cfg in yaml_openai]
        except Exception as e:
            raise FailedExecutionException(
                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
            ) from e
    # A ServerBaseline default is passed through untouched.
    if isinstance(default_server_configs, ServerBaseline):
        logger.info("Using ServerBaseline configuration.")
        return default_server_configs
    # A default list of two or more servers is also passed through untouched.
    if multi_default:
        logger.info("Using default multi-server configuration (length >= 2).")
        return default_server_configs
    # Single-server path: build one config from the merged settings.
    logger.info(
        "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
    )
    try:
        merged_config = APIServerConfig(**openai_config_dict)
    except Exception as e:
        raise FailedExecutionException(
            f"Error creating final OpenAI configuration from merged settings: {e}\n"
            f"Merged Dict: {openai_config_dict}"
        ) from e
    # Preserve the shape of the default: bare config stays bare, list stays list.
    if isinstance(default_server_configs, APIServerConfig):
        return merged_config
    if not isinstance(default_server_configs, list):
        logger.warning(
            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
            f"Proceeding with single OpenAI server configuration based on merged settings."
        )
    return [merged_config]
if __name__ == "__main__":

    async def _smoke_test():
        """Manual smoke test of tokens_and_logprobs_completion against a live server."""
        # Configure the server - update these values for your setup
        cfg = APIServerConfig(
            api_key="",  # Add your API key if needed
            base_url="http://localhost:8000",  # Update to your VLLM server URL
            model_name="Qwen/Qwen2.5-7B",  # Update to your model name
            timeout=120,
        )
        srv = VLLMServer(cfg)
        # Test the tokens_and_logprobs_completion method
        print("Testing tokens_and_logprobs_completion...")
        try:
            result = await srv.tokens_and_logprobs_completion(
                prompt="The capital of France is",
                n=4,
                max_tokens=32,
                temperature=1.0,
                top_p=1.0,
                stop=["User:", "Human:", "Assistant:", "</answer>"],
            )
            prompt_tokens, output_tokens, output_logprobs, finish_reasons = result
            print("\nResults:")
            print(f"Prompt tokens: {prompt_tokens}")
            print(f"Output tokens: {output_tokens}")
            print(f"Output logprobs (first 5): {[lp[:5] for lp in output_logprobs]}")
            print(f"Finish reasons: {finish_reasons}")
            print(f"\nNumber of completions: {len(output_tokens)}")
            print(f"Output length: {[len(tokens) for tokens in output_tokens]}")
            decoded = "\n\n".join(
                srv.tokenizer.decode(tokens) for tokens in output_tokens
            )
            print(f"Responses:\n-{decoded}")
            print(f"Health: {srv.server_healthy}")
        except Exception as e:
            print(f"Error: {e}")
            import traceback

            traceback.print_exc()

    # Run the test
    asyncio.run(_smoke_test())

View file

@ -155,7 +155,7 @@ class MathEnv(BaseEnv):
server_configs = ServerBaseline(
model_name="Qwen/Qwen2.5-7B",
num_requests_for_eval=256, # since evaling only on one...
server_type="sglang",
server_type="vllm",
)
return env_config, server_configs
@ -255,15 +255,15 @@ class MathEnv(BaseEnv):
return
async def rollout_and_score_eval(self, question, answer, subset):
completion = await self.server.completion(
prompt=question,
n=1,
max_tokens=32765,
temperature=0.0,
split="eval",
stop=stop_list,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.completion(
prompt=question,
n=1,
max_tokens=32765,
temperature=0.0,
split="eval",
stop=stop_list,
)
loop = asyncio.get_event_loop()
gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
resp = completion.choices[0].text

View file

@ -8,6 +8,10 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen
**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.
### Custom vLLM Server
The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.
## Prerequisites
1. **Python:** Python 3.8 or higher is recommended.

View file

@ -0,0 +1,224 @@
# Based on https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any
import vllm.envs as envs
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.v1.engine.async_llm import AsyncLLMEngine
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
except ImportError:
from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.outputs import RequestOutput # noqa: F401
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
    """Liveness probe: unconditionally returns HTTP 200 with an empty body."""
    return Response(status_code=200)
@app.get("/health_generate")
async def health_generate() -> Response:
    """
    Check the health of the inference server by sending a special request to generate one token.

    Returns 200 when generation completes, 499 if the request is cancelled
    (client disconnect).
    """
    assert engine is not None
    # Default SamplingParams; the one-token prompt keeps this probe cheap.
    sampling_params = SamplingParams()
    request_id = random_uuid()
    # Generate from a minimal prompt consisting of a single token id 0.
    results_generator = engine.generate(
        {"prompt_token_ids": [0]}, sampling_params, request_id
    )
    try:
        # Drain the generator; we only care that generation finishes cleanly.
        async for request_output in results_generator:
            final_output = request_output  # type: RequestOutput  # noqa: F841
    except asyncio.CancelledError:
        # Request was cancelled before generation finished.
        return Response(status_code=499)
    return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    # Delegate to the cancellation-aware worker.
    return await _generate(request_dict, raw_request=request)
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """
    Run one generation request and return either a streaming response or a
    JSON payload containing text, finish reasons and (when logprobs were
    requested) prompt/output token ids with per-token logprobs.
    """
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    # Only surface the final output for each request, not incremental deltas.
    request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
    # All remaining request fields are interpreted as sampling parameters.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()
    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        # NOTE(review): streamed chunks prepend the prompt to each output
        # text, unlike the non-streaming branch below, which returns the
        # prompt separately — confirm this asymmetry is intended.
        async for request_output in results_generator:
            prompt = request_output.prompt
            assert prompt is not None
            text_outputs = [prompt + output.text for output in request_output.outputs]
            ret = {"text": text_outputs}
            # Newline-delimited JSON chunks.
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())
    # Non-streaming case: drain the generator, keeping only the final output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output  # type: RequestOutput
    except asyncio.CancelledError:
        # Request cancelled mid-generation (client disconnect).
        return Response(status_code=499)
    assert final_output is not None
    # Token-id prompts carry no text; decode them for the response payload.
    prompt = final_output.prompt or engine.tokenizer.decode(
        final_output.prompt_token_ids
    )
    assert prompt is not None
    text_outputs = [output.text for output in final_output.outputs]
    finish_reasons = [output.finish_reason for output in final_output.outputs]
    ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
    if sampling_params.logprobs is not None:
        # Per output: a list (one entry per generated token) of single-entry
        # {token_id: logprob} mappings — the format VLLMServer parses.
        output_logprobs = [
            [
                [{key: value.logprob for key, value in logprob.items()}]
                for logprob in x.logprobs
            ]
            for x in final_output.outputs
        ]
        prompt_token_ids = final_output.prompt_token_ids
        output_token_ids = [x.token_ids for x in final_output.outputs]
        ret["logprobs"] = output_logprobs
        ret["prompt_token_ids"] = prompt_token_ids
        ret["token_ids"] = output_token_ids
    return JSONResponse(ret)
def build_app(args: Namespace) -> FastAPI:
    """Apply CLI-derived settings to the module-level FastAPI app and return it."""
    global app  # noqa: F824
    # root_path supports deployment behind a path-rewriting reverse proxy.
    app.root_path = args.root_path
    return app
async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """
    Build the FastAPI app and initialize the module-level engine.

    If *llm_engine* is provided it is used directly; otherwise a new
    AsyncLLMEngine is constructed from the parsed CLI args.
    """
    app = build_app(args)

    global engine
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (
        llm_engine
        if llm_engine is not None
        else AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    )
    # Expose the engine on app state as well as the module global.
    app.state.engine_client = engine
    return app
async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """
    Initialize the app and engine, then serve HTTP until shutdown.

    Extra keyword arguments are forwarded to uvicorn through serve_http.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Raise the open-file-descriptor limit for many concurrent connections.
    set_ulimit()

    app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )
    # Block until the HTTP server shuts down.
    await shutdown_task
if __name__ == "__main__":
    # CLI: network/SSL options here, plus every AsyncEngineArgs engine option
    # added below via add_cli_args.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    # Merge in all vLLM engine arguments (model, tensor-parallel size, etc.).
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))