add tool call parsing based on vllm impl and an openai server endpoint

dmahan93 2026-03-02 23:17:13 -06:00
parent 887a94374c
commit add42a2afb
11 changed files with 3370 additions and 34 deletions

atroposlib/envs/server_handling/managed_server_proxy.py

@@ -0,0 +1,623 @@
"""
OpenAI-compatible chat completions proxy over ManagedServer.
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
environment microservices can interact with ManagedServer via standard
OpenAI API during multi-step rollouts.
Each UUID maps to a session containing a ManagedServer instance. Tool call
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
text tool call translation only affects the HTTP wire format.
Uses ServerManager for routing across multiple backend servers (load balancing,
health checks, etc.) same infrastructure as the rest of atropos.
Usage:
# Standalone with JSON config
python -m atroposlib.envs.server_handling.managed_server_proxy \\
--config servers.json \\
--port 9100
# Or mount into existing FastAPI app
from atroposlib.envs.server_handling.managed_server_proxy import create_app
app = create_app(server_manager, tokenizer, model_name)
servers.json example:
{
"model_name": "Qwen/Qwen3-4B",
"servers": [
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
]
}
"""

import logging
import time
import uuid as uuid_lib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from atroposlib.envs.server_handling.managed_server import (
    DummyManagedServer,
    ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------


class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionProxyRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 1024
    temperature: float = 1.0
    n: int = 1
    stop: Optional[List[str]] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Any] = None  # "auto", "none", "required", or dict
    # TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
    # logprobs, top_logprobs — pass through to backend when implemented
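
# An illustrative request body for POST /{uuid}/v1/chat/completions built from
# the fields above (values are made up; tools follow the standard OpenAI
# function-tool schema):
#
#     {
#         "messages": [{"role": "user", "content": "Weather in SF?"}],
#         "max_tokens": 256,
#         "temperature": 0.7,
#         "tools": [{
#             "type": "function",
#             "function": {
#                 "name": "get_weather",
#                 "parameters": {"type": "object",
#                                "properties": {"city": {"type": "string"}}},
#             },
#         }],
#         "tool_choice": "auto"
#     }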

class SessionCreateRequest(BaseModel):
    tool_parser: str = "hermes"
    track_tree: bool = False
    # Pin to a specific backend server by its base_url. In production,
    # the caller gets their assigned server from the atropos API and
    # passes it here. If omitted, falls back to picking the server
    # with the most open semaphore slots (fine for dev/testing).
    base_url: Optional[str] = None


class SessionCreateResponse(BaseModel):
    uuid: str
    model_name: str
    tool_parser: str
    base_url: Optional[str] = None  # Which backend was selected
    created_at: float


class RenderResponse(BaseModel):
    prompt_text: str
    token_ids: List[int]
    num_tokens: int

# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------


@dataclass
class SessionState:
    uuid: str
    managed_server: ManagedServer
    translator: ToolCallTranslator
    model_name: str
    base_url: Optional[str] = None  # Which backend server this session is pinned to
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)

# ---------------------------------------------------------------------------
# OpenAI error format
# ---------------------------------------------------------------------------


def openai_error(
    status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type,
                "code": status_code,
            }
        },
    )

# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------


def create_app(
    server_manager: ServerManager,
    tokenizer: Any,
    model_name: str = "unknown",
) -> FastAPI:
    """Create the proxy FastAPI app.

    Args:
        server_manager: ServerManager instance managing one or more backend
            servers (VLLMServer, SGLangServer, etc.). Used to pick the most
            available server when creating sessions.
        tokenizer: HuggingFace tokenizer for the model.
        model_name: Model name to report in responses.

    Returns:
        FastAPI app with all endpoints registered.
    """
    app = FastAPI(title="ManagedServer OpenAI Proxy")
    sessions: Dict[str, SessionState] = {}

    # -- helpers --

    def _get_session(session_uuid: str) -> SessionState:
        session = sessions.get(session_uuid)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.last_accessed = time.time()
        return session

    def _get_server_base_url(server) -> Optional[str]:
        """Get the base_url from a server's config, if available."""
        if hasattr(server, "config") and hasattr(server.config, "base_url"):
            return server.config.base_url
        return None

    def _select_server(base_url: Optional[str] = None):
        """Pick a server from the manager.

        Args:
            base_url: If provided, pin to the server with this base_url;
                raises 404 if no server matches. If None, picks the most
                available server (mirrors ServerManager.managed_server()
                logic).

        Returns:
            The selected APIServer instance.
        """
        if base_url is not None:
            # Pin to the specific server with this base_url
            for server in server_manager.servers:
                server_url = _get_server_base_url(server)
                if server_url and server_url.rstrip("/") == base_url.rstrip("/"):
                    return server
            # No match — list available URLs in the error
            available = [
                _get_server_base_url(s)
                for s in server_manager.servers
                if _get_server_base_url(s)
            ]
            raise HTTPException(
                status_code=404,
                detail=f"No server with base_url '{base_url}'. Available: {available}",
            )
        # Auto-select the most available healthy server
        most_available_idx = 0
        most_available_slots = -1
        for i, server in enumerate(server_manager.servers):
            if not server.server_healthy:
                continue
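            # NOTE: _value is asyncio.Semaphore's internal free-slot counter.
            # Peeking at this private attribute is a cheap availability proxy
            # that avoids actually acquiring the semaphore here.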
            if server.sem._value > most_available_slots:
                most_available_idx = i
                most_available_slots = server.sem._value
        return server_manager.servers[most_available_idx]

    def _render_prompt(
        messages: List[Any],  # ChatMessage models or plain dicts
        tools: Optional[List[dict]] = None,
        translator: Optional[ToolCallTranslator] = None,
    ) -> str:
        """Render messages to prompt text via the chat template.

        Accepts either ChatMessage models or plain dicts. If a translator
        is provided, converts OpenAI tool_call messages back to raw text
        first.
        """
        # Normalize messages to dicts
        msg_dicts = []
        for msg in messages:
            if isinstance(msg, BaseModel):
                msg_dicts.append(msg.model_dump(exclude_none=True))
            else:
                msg_dicts.append(msg)
        # Reconstruct raw text for assistant tool_call messages
        if translator:
            msg_dicts = translator.convert_messages_for_template(msg_dicts)
        # Build kwargs for apply_chat_template
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools
        return tokenizer.apply_chat_template(msg_dicts, **template_kwargs)

    def _build_openai_response(
        choices_data: List[dict],
        model: str,
    ) -> dict:
        """Build an OpenAI ChatCompletion response dict."""
        return {
            "id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": choices_data,
        }

    # -- endpoints --

    @app.get("/health")
    async def health():
        healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
        return {
            "status": "ok",
            "model": model_name,
            "sessions": len(sessions),
            "servers": len(server_manager.servers),
            "healthy_servers": healthy_servers,
        }

    @app.get("/v1/models")
    async def list_models():
        return {
            "object": "list",
            "data": [
                {
                    "id": model_name,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "atropos",
                }
            ],
        }
@app.post("/setup")
async def setup(request: Request):
"""Receive server configuration from ServerManager.
Accepts the same JSON format as the standalone --config file.
Replaces the current server_manager's servers with the new config.
Called by ServerManager at startup to push its config to the proxy.
Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
"""
config = await request.json()
new_configs = []
new_model = config.get("model_name", model_name)
for srv in config.get("servers", []):
new_configs.append(
APIServerConfig(
model_name=new_model,
base_url=srv["base_url"],
api_key=srv.get("api_key", ""),
server_type=srv.get("server_type", "vllm"),
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
timeout=srv.get("timeout", 1200),
tokenizer_name=config.get("tokenizer_name", "none"),
)
)
if new_configs:
new_manager = ServerManager(configs=new_configs)
server_manager.servers = new_manager.servers
logger.info("Setup: replaced servers with %d new configs", len(new_configs))
return {
"status": "ok",
"servers": len(server_manager.servers),
"model_name": new_model,
}
@app.get("/servers")
async def list_servers():
"""List available backend servers.
Useful for discovery/debugging. In production, server allocation
is managed by the atropos API the environment gets told which
server to use and passes that base_url to POST /sessions/create.
"""
server_list = []
for i, server in enumerate(server_manager.servers):
url = _get_server_base_url(server)
healthy = getattr(server, "server_healthy", True)
server_list.append(
{
"index": i,
"base_url": url,
"healthy": healthy,
"model_name": (
getattr(server.config, "model_name", model_name)
if hasattr(server, "config")
else model_name
),
"server_type": (
getattr(server.config, "server_type", "unknown")
if hasattr(server, "config")
else "unknown"
),
}
)
return {"servers": server_list}
@app.post("/sessions/create", response_model=SessionCreateResponse)
async def create_session(request: SessionCreateRequest):
session_uuid = str(uuid_lib.uuid4())
# Pick server — pinned to base_url if specified, otherwise most available
selected_server = _select_server(base_url=request.base_url)
selected_url = _get_server_base_url(selected_server)
# Use DummyManagedServer for OpenAI endpoints (no logprobs support)
if isinstance(selected_server, OpenAIServer):
logger.info(
"Session %s using DummyManagedServer (OpenAI endpoint). "
"Token IDs and logprobs will be placeholders.",
session_uuid,
)
managed = DummyManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
)
else:
managed = ManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
tool_parser=request.tool_parser,
)
# Translator kept for the render endpoint (prompt preview)
translator = ToolCallTranslator(
tokenizer=tokenizer,
parser_name=request.tool_parser,
)
session = SessionState(
uuid=session_uuid,
managed_server=managed,
translator=translator,
model_name=model_name,
base_url=selected_url,
)
sessions[session_uuid] = session
return SessionCreateResponse(
uuid=session_uuid,
model_name=model_name,
tool_parser=request.tool_parser,
base_url=selected_url,
created_at=session.created_at,
)
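
    # A minimal end-to-end sketch of the session flow, assuming a proxy on
    # localhost:9100 and the third-party `requests` package (not used by this
    # module):
    #
    #     import requests
    #     base = "http://localhost:9100"
    #     sid = requests.post(f"{base}/sessions/create",
    #                         json={"tool_parser": "hermes"}).json()["uuid"]
    #     out = requests.post(f"{base}/{sid}/v1/chat/completions",
    #                         json={"messages": [{"role": "user", "content": "hi"}]})
    #     print(out.json()["choices"][0]["message"]["content"])
    #     requests.delete(f"{base}/{sid}")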
@app.get("/sessions")
async def list_sessions():
return {
"sessions": [
{
"uuid": s.uuid,
"model_name": s.model_name,
"base_url": s.base_url,
"created_at": s.created_at,
"last_accessed": s.last_accessed,
"num_nodes": len(
s.managed_server.current_nodes
if hasattr(s.managed_server, "current_nodes")
else s.managed_server.sequences
),
}
for s in sessions.values()
]
}
@app.post("/{session_uuid}/v1/chat/completions")
async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
session = _get_session(session_uuid)
managed = session.managed_server
if not request.messages:
return openai_error(400, "messages must not be empty")
# Convert pydantic messages to dicts
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
# Build kwargs — ManagedServer.chat_completion() handles all tool
# call logic internally (template rendering, inbound reconstruction,
# outbound parsing, skip_special_tokens)
completion_kwargs = {
"messages": messages,
"n": request.n,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
}
if request.stop:
completion_kwargs["stop"] = request.stop
if request.tools:
completion_kwargs["tools"] = request.tools
if request.tool_choice is not None:
completion_kwargs["tool_choice"] = request.tool_choice
try:
result = await managed.chat_completion(**completion_kwargs)
except Exception as e:
logger.exception("Completion failed")
return openai_error(
500, f"Completion failed: {e}", error_type="server_error"
)
# Convert ChatCompletion to JSON-serializable response
choices = []
for choice in result.choices:
choice_data = {
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content,
},
"finish_reason": choice.finish_reason,
}
if choice.message.tool_calls:
choice_data["message"]["tool_calls"] = choice.message.tool_calls
choices.append(choice_data)
return _build_openai_response(choices, model_name)
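
    # The resulting payload follows the standard ChatCompletion shape from
    # _build_openai_response above, e.g. (illustrative values):
    #
    #     {"id": "chatcmpl-ab12cd34ef56", "object": "chat.completion",
    #      "created": 1709000000, "model": "Qwen/Qwen3-4B",
    #      "choices": [{"index": 0, "finish_reason": "tool_calls",
    #                   "message": {"role": "assistant", "content": "...",
    #                               "tool_calls": [...]}}]}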

    @app.post("/{session_uuid}/v1/chat/completions/render")
    async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
        session = _get_session(session_uuid)
        try:
            prompt_text = _render_prompt(
                messages=request.messages,
                tools=request.tools,
                translator=session.translator,
            )
        except Exception as e:
            return openai_error(400, f"Failed to render prompt: {e}")
        token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        return RenderResponse(
            prompt_text=prompt_text,
            token_ids=token_ids,
            num_tokens=len(token_ids),
        )

    @app.get("/{session_uuid}/nodes")
    async def get_nodes(session_uuid: str):
        session = _get_session(session_uuid)
        state = session.managed_server.get_state()
        if session.managed_server.track_tree:
            nodes = list(state.get("sequences", {}).values())
        else:
            nodes = state.get("nodes", [])
        return {
            "nodes": [node.model_dump() for node in nodes],
        }

    @app.delete("/{session_uuid}")
    async def delete_session(session_uuid: str):
        session = sessions.pop(session_uuid, None)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.managed_server.reset()
        return {"status": "deleted", "uuid": session_uuid}

    # -- exception handler for OpenAI-style errors --

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        return openai_error(exc.status_code, exc.detail)

    # NOTE on concurrent access: each UUID is meant to represent a single
    # rollout session used by one caller at a time. If you send concurrent
    # requests to the same UUID, the ManagedServer's node extension logic
    # may get confused because it does prefix matching on current_nodes.
    # Don't do that. UUIDs are cheap; make a new one.
    return app
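
# To embed the proxy under a prefix of an existing FastAPI app instead of
# running it standalone, Starlette's mount works (a sketch; `parent` and the
# "/proxy" prefix are illustrative):
#
#     parent = FastAPI()
#     parent.mount("/proxy", create_app(server_manager, tokenizer, model_name))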

# ---------------------------------------------------------------------------
# Standalone entrypoint
# ---------------------------------------------------------------------------


def main():
    import argparse
    import json

    import uvicorn
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to JSON config file with server definitions",
    )
    parser.add_argument("--port", type=int, default=9100, help="Proxy port")
    parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)
    model_name = config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))

    # Build APIServerConfigs from the JSON
    server_configs = []
    for srv in config["servers"]:
        server_configs.append(
            APIServerConfig(
                model_name=model_name,
                base_url=srv["base_url"],
                api_key=srv.get("api_key", ""),
                server_type=srv.get("server_type", "vllm"),
                num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                timeout=srv.get("timeout", 1200),
                tokenizer_name=config.get("tokenizer_name", "none"),
            )
        )
    server_manager = ServerManager(configs=server_configs)

    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name=model_name,
    )

    print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
    print(f"  Model: {model_name}")
    print(f"  Backends: {len(server_configs)} server(s)")
    for i, cfg in enumerate(server_configs):
        print(f"    [{i}] {cfg.base_url} ({cfg.server_type})")
    print()
    print("Endpoints:")
    print("  POST   /sessions/create")
    print("  POST   /{uuid}/v1/chat/completions")
    print("  POST   /{uuid}/v1/chat/completions/render")
    print("  GET    /{uuid}/nodes")
    print("  DELETE /{uuid}")
    print("  GET    /sessions")
    print("  GET    /v1/models")
    print("  GET    /health")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()