# atropos/atroposlib/envs/server_handling/managed_server_proxy.py
"""
OpenAI-compatible chat completions proxy over ManagedServer.
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
environment microservices can interact with ManagedServer via standard
OpenAI API during multi-step rollouts.
Each UUID maps to a session containing a ManagedServer instance. Tool call
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
text — tool call translation only affects the HTTP wire format.
Uses ServerManager for routing across multiple backend servers (load balancing,
health checks, etc.) — same infrastructure as the rest of atropos.
Usage:
# Standalone with JSON config
python -m atroposlib.envs.server_handling.managed_server_proxy \\
--config servers.json \\
--port 9100
# Or mount into existing FastAPI app
from atroposlib.envs.server_handling.managed_server_proxy import create_app
app = create_app(server_manager, tokenizer, model_name)
servers.json example:
{
"model_name": "Qwen/Qwen3-4B",
"servers": [
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
]
}
"""
import logging
import time
import uuid as uuid_lib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from atroposlib.envs.server_handling.managed_server import (
    DummyManagedServer,
    ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionProxyRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 1024
    temperature: float = 1.0
    n: int = 1
    stop: Optional[List[str]] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Any] = None  # "auto", "none", "required", or dict
    # TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
    # logprobs, top_logprobs — pass through to backend when implemented


class SessionCreateRequest(BaseModel):
    tool_parser: str = "hermes"
    track_tree: bool = False
    # Pin to a specific backend server by its base_url. In production,
    # the caller gets their assigned server from the atropos API and
    # passes it here. If omitted, falls back to picking the server
    # with the most open semaphore slots (fine for dev/testing).
    base_url: Optional[str] = None
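
# An illustrative POST /sessions/create body that pins the session to one
# backend (the URL is a placeholder, not a real default):
#
#   {"tool_parser": "hermes", "track_tree": false,
#    "base_url": "http://gpu1:8000/v1"}
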
class SessionCreateResponse(BaseModel):
    uuid: str
    model_name: str
    tool_parser: str
    base_url: Optional[str] = None  # Which backend was selected
    created_at: float


class RenderResponse(BaseModel):
    prompt_text: str
    token_ids: List[int]
    num_tokens: int


# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------
@dataclass
class SessionState:
    uuid: str
    managed_server: ManagedServer
    translator: ToolCallTranslator
    model_name: str
    base_url: Optional[str] = None  # Which backend server this session is pinned to
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)
# ---------------------------------------------------------------------------
# OpenAI error format
# ---------------------------------------------------------------------------
def openai_error(
    status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type,
                "code": status_code,
            }
        },
    )
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(
    server_manager: ServerManager,
    tokenizer: Any,
    model_name: str = "unknown",
) -> FastAPI:
    """Create the proxy FastAPI app.

    Args:
        server_manager: ServerManager instance managing one or more backend
            servers (VLLMServer, SGLangServer, etc.). Used to pick the most
            available server when creating sessions.
        tokenizer: HuggingFace tokenizer for the model.
        model_name: Model name to report in responses.

    Returns:
        FastAPI app with all endpoints registered.
    """
    app = FastAPI(title="ManagedServer OpenAI Proxy")
    sessions: Dict[str, SessionState] = {}

    # -- helpers --

    def _get_session(session_uuid: str) -> SessionState:
        session = sessions.get(session_uuid)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.last_accessed = time.time()
        return session

    def _get_server_base_url(server) -> Optional[str]:
        """Get the base_url from a server's config, if available."""
        if hasattr(server, "config") and hasattr(server.config, "base_url"):
            return server.config.base_url
        return None

    def _select_server(base_url: Optional[str] = None):
        """Pick a server from the manager.

        Args:
            base_url: If provided, pin to the server with this base_url.
                Raises 404 if no server matches.
                If None, picks the most available server (mirrors
                ServerManager.managed_server() logic).

        Returns:
            Selected APIServer instance.
        """
        if base_url is not None:
            # Pin to specific server by base_url
            for server in server_manager.servers:
                server_url = _get_server_base_url(server)
                if server_url and server_url.rstrip("/") == base_url.rstrip("/"):
                    return server
            # No match — list available URLs in error
            available = [
                _get_server_base_url(s)
                for s in server_manager.servers
                if _get_server_base_url(s)
            ]
            raise HTTPException(
                status_code=404,
                detail=f"No server with base_url '{base_url}'. Available: {available}",
            )
        # Auto-select most available
        most_available_idx = 0
        most_available_slots = -1
        for i, server in enumerate(server_manager.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_slots:
                most_available_idx = i
                most_available_slots = server.sem._value
        return server_manager.servers[most_available_idx]

    def _render_prompt(
        messages: List[Dict[str, Any]],
        tools: Optional[List[dict]] = None,
        translator: Optional[ToolCallTranslator] = None,
    ) -> str:
        """Render messages to prompt text via the chat template.

        If a translator is provided, converts OpenAI tool_call messages
        back to raw text first.
        """
        # Convert messages to dicts
        msg_dicts = []
        for msg in messages:
            if isinstance(msg, BaseModel):
                msg_dicts.append(msg.model_dump(exclude_none=True))
            else:
                msg_dicts.append(msg)
        # Reconstruct raw text for assistant tool_call messages
        if translator:
            msg_dicts = translator.convert_messages_for_template(msg_dicts)
        # Build kwargs for apply_chat_template
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools
        return tokenizer.apply_chat_template(msg_dicts, **template_kwargs)

    def _build_openai_response(
        choices_data: List[dict],
        model: str,
    ) -> dict:
        """Build an OpenAI ChatCompletion response dict."""
        return {
            "id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": choices_data,
        }

    # -- endpoints --

    @app.get("/health")
    async def health():
        healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
        return {
            "status": "ok",
            "model": model_name,
            "sessions": len(sessions),
            "servers": len(server_manager.servers),
            "healthy_servers": healthy_servers,
        }

    @app.get("/v1/models")
    async def list_models():
        return {
            "object": "list",
            "data": [
                {
                    "id": model_name,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "atropos",
                }
            ],
        }
@app.post("/setup")
async def setup(request: Request):
"""Receive server configuration from ServerManager.
Accepts the same JSON format as the standalone --config file.
Replaces the current server_manager's servers with the new config.
Called by ServerManager at startup to push its config to the proxy.
Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
"""
config = await request.json()
new_configs = []
new_model = config.get("model_name", model_name)
for srv in config.get("servers", []):
new_configs.append(
APIServerConfig(
model_name=new_model,
base_url=srv["base_url"],
api_key=srv.get("api_key", ""),
server_type=srv.get("server_type", "vllm"),
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
timeout=srv.get("timeout", 1200),
tokenizer_name=config.get("tokenizer_name", "none"),
)
)
if new_configs:
new_manager = ServerManager(configs=new_configs)
server_manager.servers = new_manager.servers
logger.info("Setup: replaced servers with %d new configs", len(new_configs))
return {
"status": "ok",
"servers": len(server_manager.servers),
"model_name": new_model,
}
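
    # An illustrative push of a fresh backend config to this endpoint, e.g.
    # from the host running ServerManager (URLs and model are placeholders):
    #
    #   curl -X POST http://localhost:9100/setup \
    #     -H 'Content-Type: application/json' \
    #     -d '{"model_name": "Qwen/Qwen3-4B",
    #          "servers": [{"base_url": "http://gpu1:8000/v1",
    #                       "server_type": "vllm"}]}'
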
@app.get("/servers")
async def list_servers():
"""List available backend servers.
Useful for discovery/debugging. In production, server allocation
is managed by the atropos API — the environment gets told which
server to use and passes that base_url to POST /sessions/create.
"""
server_list = []
for i, server in enumerate(server_manager.servers):
url = _get_server_base_url(server)
healthy = getattr(server, "server_healthy", True)
server_list.append(
{
"index": i,
"base_url": url,
"healthy": healthy,
"model_name": (
getattr(server.config, "model_name", model_name)
if hasattr(server, "config")
else model_name
),
"server_type": (
getattr(server.config, "server_type", "unknown")
if hasattr(server, "config")
else "unknown"
),
}
)
return {"servers": server_list}
@app.post("/sessions/create", response_model=SessionCreateResponse)
async def create_session(request: SessionCreateRequest):
session_uuid = str(uuid_lib.uuid4())
# Pick server — pinned to base_url if specified, otherwise most available
selected_server = _select_server(base_url=request.base_url)
selected_url = _get_server_base_url(selected_server)
# Use DummyManagedServer for OpenAI endpoints (no logprobs support)
if isinstance(selected_server, OpenAIServer):
logger.info(
"Session %s using DummyManagedServer (OpenAI endpoint). "
"Token IDs and logprobs will be placeholders.",
session_uuid,
)
managed = DummyManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
)
else:
managed = ManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
tool_parser=request.tool_parser,
)
# Translator kept for the render endpoint (prompt preview)
translator = ToolCallTranslator(
tokenizer=tokenizer,
parser_name=request.tool_parser,
)
session = SessionState(
uuid=session_uuid,
managed_server=managed,
translator=translator,
model_name=model_name,
base_url=selected_url,
)
sessions[session_uuid] = session
return SessionCreateResponse(
uuid=session_uuid,
model_name=model_name,
tool_parser=request.tool_parser,
base_url=selected_url,
created_at=session.created_at,
)
@app.get("/sessions")
async def list_sessions():
return {
"sessions": [
{
"uuid": s.uuid,
"model_name": s.model_name,
"base_url": s.base_url,
"created_at": s.created_at,
"last_accessed": s.last_accessed,
"num_nodes": len(
s.managed_server.current_nodes
if hasattr(s.managed_server, "current_nodes")
else s.managed_server.sequences
),
}
for s in sessions.values()
]
}
@app.post("/{session_uuid}/v1/chat/completions")
async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
session = _get_session(session_uuid)
managed = session.managed_server
if not request.messages:
return openai_error(400, "messages must not be empty")
# Convert pydantic messages to dicts
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
# Build kwargs — ManagedServer.chat_completion() handles all tool
# call logic internally (template rendering, inbound reconstruction,
# outbound parsing, skip_special_tokens)
completion_kwargs = {
"messages": messages,
"n": request.n,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
}
if request.stop:
completion_kwargs["stop"] = request.stop
if request.tools:
completion_kwargs["tools"] = request.tools
if request.tool_choice is not None:
completion_kwargs["tool_choice"] = request.tool_choice
try:
result = await managed.chat_completion(**completion_kwargs)
except Exception as e:
logger.exception("Completion failed")
return openai_error(
500, f"Completion failed: {e}", error_type="server_error"
)
# Convert ChatCompletion to JSON-serializable response
choices = []
for choice in result.choices:
choice_data = {
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content,
},
"finish_reason": choice.finish_reason,
}
if choice.message.tool_calls:
choice_data["message"]["tool_calls"] = choice.message.tool_calls
choices.append(choice_data)
return _build_openai_response(choices, model_name)
@app.post("/{session_uuid}/v1/chat/completions/render")
async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
session = _get_session(session_uuid)
try:
prompt_text = _render_prompt(
messages=request.messages,
tools=request.tools,
translator=session.translator,
)
except Exception as e:
return openai_error(400, f"Failed to render prompt: {e}")
token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
return RenderResponse(
prompt_text=prompt_text,
token_ids=token_ids,
num_tokens=len(token_ids),
)
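
    # The render endpoint is a dry run: it returns the exact prompt the
    # backend would see, without generating. A sketch of a preview call,
    # assuming `requests` and a proxy on localhost:9100 (both illustrative):
    #
    #   preview = requests.post(
    #       f"http://localhost:9100/{session_uuid}/v1/chat/completions/render",
    #       json={"messages": [{"role": "user", "content": "hi"}]},
    #   ).json()
    #   print(preview["num_tokens"], preview["prompt_text"])
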
@app.get("/{session_uuid}/nodes")
async def get_nodes(session_uuid: str):
session = _get_session(session_uuid)
state = session.managed_server.get_state()
if session.managed_server.track_tree:
nodes = list(state.get("sequences", {}).values())
else:
nodes = state.get("nodes", [])
return {
"nodes": [node.model_dump() for node in nodes],
}
@app.delete("/{session_uuid}")
async def delete_session(session_uuid: str):
session = sessions.pop(session_uuid, None)
if session is None:
raise HTTPException(
status_code=404, detail=f"Session {session_uuid} not found"
)
session.managed_server.reset()
return {"status": "deleted", "uuid": session_uuid}

    # -- exception handler for OpenAI-style errors --

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        return openai_error(exc.status_code, exc.detail)

    # NOTE on concurrent access: each UUID is meant to represent a single
    # rollout session used by one caller at a time. If you send concurrent
    # requests to the same UUID, the ManagedServer's node extension logic
    # may get confused because it does prefix matching on current_nodes.
    # Don't do that; UUIDs are cheap, so make a new one.
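    #
    # For parallel rollouts, the intended pattern is one session per rollout,
    # e.g. (a sketch; assumes `requests`, and `PROXY` is a placeholder URL):
    #
    #   uuids = [
    #       requests.post(f"{PROXY}/sessions/create", json={}).json()["uuid"]
    #       for _ in range(num_rollouts)
    #   ]
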
    return app
# ---------------------------------------------------------------------------
# Standalone entrypoint
# ---------------------------------------------------------------------------
def main():
    import argparse
    import json

    import uvicorn
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to JSON config file with server definitions",
    )
    parser.add_argument("--port", type=int, default=9100, help="Proxy port")
    parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)
    model_name = config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))

    # Build APIServerConfigs from the JSON
    server_configs = []
    for srv in config["servers"]:
        server_configs.append(
            APIServerConfig(
                model_name=model_name,
                base_url=srv["base_url"],
                api_key=srv.get("api_key", ""),
                server_type=srv.get("server_type", "vllm"),
                num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                timeout=srv.get("timeout", 1200),
                tokenizer_name=config.get("tokenizer_name", "none"),
            )
        )
    server_manager = ServerManager(configs=server_configs)
    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name=model_name,
    )

    print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
    print(f"  Model: {model_name}")
    print(f"  Backends: {len(server_configs)} server(s)")
    for i, cfg in enumerate(server_configs):
        print(f"    [{i}] {cfg.base_url} ({cfg.server_type})")
    print()
    print("Endpoints:")
    print("  POST   /sessions/create")
    print("  POST   /{uuid}/v1/chat/completions")
    print("  POST   /{uuid}/v1/chat/completions/render")
    print("  GET    /{uuid}/nodes")
    print("  DELETE /{uuid}")
    print("  GET    /sessions")
    print("  GET    /v1/models")
    print("  GET    /health")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()