"""
|
|
OpenAI-compatible chat completions proxy over ManagedServer.
|
|
|
|
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
|
|
environment microservices can interact with ManagedServer via standard
|
|
OpenAI API during multi-step rollouts.
|
|
|
|
Each UUID maps to a session containing a ManagedServer instance. Tool call
|
|
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
|
|
text — tool call translation only affects the HTTP wire format.
|
|
|
|
Uses ServerManager for routing across multiple backend servers (load balancing,
|
|
health checks, etc.) — same infrastructure as the rest of atropos.
|
|
|
|
Usage:
|
|
# Standalone with JSON config
|
|
python -m atroposlib.envs.server_handling.managed_server_proxy \\
|
|
--config servers.json \\
|
|
--port 9100
|
|
|
|
# Or mount into existing FastAPI app
|
|
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
|
app = create_app(server_manager, tokenizer, model_name)
|
|
|
|
servers.json example:
|
|
{
|
|
"model_name": "Qwen/Qwen3-4B",
|
|
"servers": [
|
|
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
|
|
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
|
|
]
|
|
}
|
|
"""

import logging
import time
import uuid as uuid_lib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from atroposlib.envs.server_handling.managed_server import (
    DummyManagedServer,
    ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------


class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionProxyRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 1024
    temperature: float = 1.0
    n: int = 1
    stop: Optional[List[str]] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Any] = None  # "auto", "none", "required", or dict
    # TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
    # logprobs, top_logprobs — pass through to backend when implemented


class SessionCreateRequest(BaseModel):
    tool_parser: str = "hermes"
    track_tree: bool = False
    # Pin to a specific backend server by its base_url. In production,
    # the caller gets their assigned server from the atropos API and
    # passes it here. If omitted, falls back to picking the server
    # with the most open semaphore slots (fine for dev/testing).
    # See the example sketch after this class.
    base_url: Optional[str] = None
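
# Example session pinned to a specific backend (a sketch; the URLs are
# illustrative and `requests` is not a dependency of this module). The
# base_url is typically obtained from GET /servers or assigned by the
# atropos API:
#
#     requests.post(
#         "http://localhost:9100/sessions/create",
#         json={"tool_parser": "hermes", "base_url": "http://gpu1:8000/v1"},
#     )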


class SessionCreateResponse(BaseModel):
    uuid: str
    model_name: str
    tool_parser: str
    base_url: Optional[str] = None  # Which backend was selected
    created_at: float


class RenderResponse(BaseModel):
    prompt_text: str
    token_ids: List[int]
    num_tokens: int


# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------


@dataclass
class SessionState:
    uuid: str
    managed_server: ManagedServer
    translator: ToolCallTranslator
    model_name: str
    base_url: Optional[str] = None  # Which backend server this session is pinned to
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)


# ---------------------------------------------------------------------------
# OpenAI error format
# ---------------------------------------------------------------------------


def openai_error(
    status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type,
                "code": status_code,
            }
        },
    )


# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------


def create_app(
    server_manager: ServerManager,
    tokenizer: Any,
    model_name: str = "unknown",
) -> FastAPI:
    """Create the proxy FastAPI app.

    Args:
        server_manager: ServerManager instance managing one or more backend
            servers (VLLMServer, SGLangServer, etc.). Used to pick the most
            available server when creating sessions.
        tokenizer: HuggingFace tokenizer for the model.
        model_name: Model name to report in responses.

    Returns:
        FastAPI app with all endpoints registered.
    """

    app = FastAPI(title="ManagedServer OpenAI Proxy")
    sessions: Dict[str, SessionState] = {}

    # -- helpers --

    def _get_session(session_uuid: str) -> SessionState:
        session = sessions.get(session_uuid)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.last_accessed = time.time()
        return session

    def _get_server_base_url(server) -> Optional[str]:
        """Get the base_url from a server's config, if available."""
        if hasattr(server, "config") and hasattr(server.config, "base_url"):
            return server.config.base_url
        return None

    def _select_server(base_url: Optional[str] = None):
        """Pick a server from the manager.

        Args:
            base_url: If provided, pin to the server with this base_url.
                Raises 404 if no server matches.
                If None, picks the most available server (mirrors
                ServerManager.managed_server() logic).

        Returns:
            Selected APIServer instance.
        """
        if base_url is not None:
            # Pin to specific server by base_url
            for server in server_manager.servers:
                server_url = _get_server_base_url(server)
                if server_url and server_url.rstrip("/") == base_url.rstrip("/"):
                    return server
            # No match — list available URLs in error
            available = [
                _get_server_base_url(s)
                for s in server_manager.servers
                if _get_server_base_url(s)
            ]
            raise HTTPException(
                status_code=404,
                detail=f"No server with base_url '{base_url}'. Available: {available}",
            )

        # Auto-select most available
        most_available_idx = 0
        most_available_slots = -1
        for i, server in enumerate(server_manager.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_slots:
                most_available_idx = i
                most_available_slots = server.sem._value
        return server_manager.servers[most_available_idx]

    def _render_prompt(
        messages: List[Dict[str, Any]],
        tools: Optional[List[dict]] = None,
        translator: Optional[ToolCallTranslator] = None,
    ) -> str:
        """Render messages to prompt text via chat template.

        If a translator is provided, converts OpenAI tool_call messages
        back to raw text first.
        """
        # Convert messages to dicts
        msg_dicts = []
        for msg in messages:
            if isinstance(msg, BaseModel):
                msg_dicts.append(msg.model_dump(exclude_none=True))
            else:
                msg_dicts.append(msg)

        # Reconstruct raw text for assistant tool_call messages
        if translator:
            msg_dicts = translator.convert_messages_for_template(msg_dicts)

        # Build kwargs for apply_chat_template
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools

        return tokenizer.apply_chat_template(msg_dicts, **template_kwargs)

    def _build_openai_response(
        choices_data: List[dict],
        model: str,
    ) -> dict:
        """Build an OpenAI ChatCompletion response dict."""
        return {
            "id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": choices_data,
        }

    # -- endpoints --
@app.get("/health")
|
|
async def health():
|
|
healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
|
|
return {
|
|
"status": "ok",
|
|
"model": model_name,
|
|
"sessions": len(sessions),
|
|
"servers": len(server_manager.servers),
|
|
"healthy_servers": healthy_servers,
|
|
}
|
|
|
|
@app.get("/v1/models")
|
|
async def list_models():
|
|
return {
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"id": model_name,
|
|
"object": "model",
|
|
"created": int(time.time()),
|
|
"owned_by": "atropos",
|
|
}
|
|
],
|
|
}
|
|
|
|
@app.post("/setup")
|
|
async def setup(request: Request):
|
|
"""Receive server configuration from ServerManager.
|
|
|
|
Accepts the same JSON format as the standalone --config file.
|
|
Replaces the current server_manager's servers with the new config.
|
|
Called by ServerManager at startup to push its config to the proxy.
|
|
|
|
Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
|
|
"""
|
|
config = await request.json()
|
|
|
|
new_configs = []
|
|
new_model = config.get("model_name", model_name)
|
|
for srv in config.get("servers", []):
|
|
new_configs.append(
|
|
APIServerConfig(
|
|
model_name=new_model,
|
|
base_url=srv["base_url"],
|
|
api_key=srv.get("api_key", ""),
|
|
server_type=srv.get("server_type", "vllm"),
|
|
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
|
|
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
|
|
timeout=srv.get("timeout", 1200),
|
|
tokenizer_name=config.get("tokenizer_name", "none"),
|
|
)
|
|
)
|
|
|
|
if new_configs:
|
|
new_manager = ServerManager(configs=new_configs)
|
|
server_manager.servers = new_manager.servers
|
|
logger.info("Setup: replaced servers with %d new configs", len(new_configs))
|
|
|
|
return {
|
|
"status": "ok",
|
|
"servers": len(server_manager.servers),
|
|
"model_name": new_model,
|
|
}
|
|
|
|
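
    # Example config push to /setup (a sketch; the URL and payload values are
    # illustrative, and `requests` is not a dependency of this module). The
    # body uses the same JSON shape as the standalone --config file:
    #
    #     requests.post(
    #         "http://localhost:9100/setup",
    #         json={
    #             "model_name": "Qwen/Qwen3-4B",
    #             "servers": [{"base_url": "http://gpu1:8000/v1"}],
    #         },
    #     )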
@app.get("/servers")
|
|
async def list_servers():
|
|
"""List available backend servers.
|
|
|
|
Useful for discovery/debugging. In production, server allocation
|
|
is managed by the atropos API — the environment gets told which
|
|
server to use and passes that base_url to POST /sessions/create.
|
|
"""
|
|
server_list = []
|
|
for i, server in enumerate(server_manager.servers):
|
|
url = _get_server_base_url(server)
|
|
healthy = getattr(server, "server_healthy", True)
|
|
server_list.append(
|
|
{
|
|
"index": i,
|
|
"base_url": url,
|
|
"healthy": healthy,
|
|
"model_name": (
|
|
getattr(server.config, "model_name", model_name)
|
|
if hasattr(server, "config")
|
|
else model_name
|
|
),
|
|
"server_type": (
|
|
getattr(server.config, "server_type", "unknown")
|
|
if hasattr(server, "config")
|
|
else "unknown"
|
|
),
|
|
}
|
|
)
|
|
return {"servers": server_list}
|
|
|
|
@app.post("/sessions/create", response_model=SessionCreateResponse)
|
|
async def create_session(request: SessionCreateRequest):
|
|
session_uuid = str(uuid_lib.uuid4())
|
|
|
|
# Pick server — pinned to base_url if specified, otherwise most available
|
|
selected_server = _select_server(base_url=request.base_url)
|
|
selected_url = _get_server_base_url(selected_server)
|
|
|
|
# Use DummyManagedServer for OpenAI endpoints (no logprobs support)
|
|
if isinstance(selected_server, OpenAIServer):
|
|
logger.info(
|
|
"Session %s using DummyManagedServer (OpenAI endpoint). "
|
|
"Token IDs and logprobs will be placeholders.",
|
|
session_uuid,
|
|
)
|
|
managed = DummyManagedServer(
|
|
server=selected_server,
|
|
tokenizer=tokenizer,
|
|
track_tree=request.track_tree,
|
|
)
|
|
else:
|
|
managed = ManagedServer(
|
|
server=selected_server,
|
|
tokenizer=tokenizer,
|
|
track_tree=request.track_tree,
|
|
tool_parser=request.tool_parser,
|
|
)
|
|
|
|
# Translator kept for the render endpoint (prompt preview)
|
|
translator = ToolCallTranslator(
|
|
tokenizer=tokenizer,
|
|
parser_name=request.tool_parser,
|
|
)
|
|
session = SessionState(
|
|
uuid=session_uuid,
|
|
managed_server=managed,
|
|
translator=translator,
|
|
model_name=model_name,
|
|
base_url=selected_url,
|
|
)
|
|
sessions[session_uuid] = session
|
|
|
|
return SessionCreateResponse(
|
|
uuid=session_uuid,
|
|
model_name=model_name,
|
|
tool_parser=request.tool_parser,
|
|
base_url=selected_url,
|
|
created_at=session.created_at,
|
|
)
|
|
|
|
@app.get("/sessions")
|
|
async def list_sessions():
|
|
return {
|
|
"sessions": [
|
|
{
|
|
"uuid": s.uuid,
|
|
"model_name": s.model_name,
|
|
"base_url": s.base_url,
|
|
"created_at": s.created_at,
|
|
"last_accessed": s.last_accessed,
|
|
"num_nodes": len(
|
|
s.managed_server.current_nodes
|
|
if hasattr(s.managed_server, "current_nodes")
|
|
else s.managed_server.sequences
|
|
),
|
|
}
|
|
for s in sessions.values()
|
|
]
|
|
}
|
|
|
|
@app.post("/{session_uuid}/v1/chat/completions")
|
|
async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
|
|
session = _get_session(session_uuid)
|
|
managed = session.managed_server
|
|
|
|
if not request.messages:
|
|
return openai_error(400, "messages must not be empty")
|
|
|
|
# Convert pydantic messages to dicts
|
|
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
|
|
|
|
# Build kwargs — ManagedServer.chat_completion() handles all tool
|
|
# call logic internally (template rendering, inbound reconstruction,
|
|
# outbound parsing, skip_special_tokens)
|
|
completion_kwargs = {
|
|
"messages": messages,
|
|
"n": request.n,
|
|
"max_tokens": request.max_tokens,
|
|
"temperature": request.temperature,
|
|
}
|
|
if request.stop:
|
|
completion_kwargs["stop"] = request.stop
|
|
if request.tools:
|
|
completion_kwargs["tools"] = request.tools
|
|
if request.tool_choice is not None:
|
|
completion_kwargs["tool_choice"] = request.tool_choice
|
|
|
|
try:
|
|
result = await managed.chat_completion(**completion_kwargs)
|
|
except Exception as e:
|
|
logger.exception("Completion failed")
|
|
return openai_error(
|
|
500, f"Completion failed: {e}", error_type="server_error"
|
|
)
|
|
|
|
# Convert ChatCompletion to JSON-serializable response
|
|
choices = []
|
|
for choice in result.choices:
|
|
choice_data = {
|
|
"index": choice.index,
|
|
"message": {
|
|
"role": choice.message.role,
|
|
"content": choice.message.content,
|
|
},
|
|
"finish_reason": choice.finish_reason,
|
|
}
|
|
if choice.message.tool_calls:
|
|
choice_data["message"]["tool_calls"] = choice.message.tool_calls
|
|
|
|
choices.append(choice_data)
|
|
|
|
return _build_openai_response(choices, model_name)
|
|
|
|
@app.post("/{session_uuid}/v1/chat/completions/render")
|
|
async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
|
|
session = _get_session(session_uuid)
|
|
|
|
try:
|
|
prompt_text = _render_prompt(
|
|
messages=request.messages,
|
|
tools=request.tools,
|
|
translator=session.translator,
|
|
)
|
|
except Exception as e:
|
|
return openai_error(400, f"Failed to render prompt: {e}")
|
|
|
|
token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
|
|
|
|
return RenderResponse(
|
|
prompt_text=prompt_text,
|
|
token_ids=token_ids,
|
|
num_tokens=len(token_ids),
|
|
)
|
|
|
|
@app.get("/{session_uuid}/nodes")
|
|
async def get_nodes(session_uuid: str):
|
|
session = _get_session(session_uuid)
|
|
state = session.managed_server.get_state()
|
|
|
|
if session.managed_server.track_tree:
|
|
nodes = list(state.get("sequences", {}).values())
|
|
else:
|
|
nodes = state.get("nodes", [])
|
|
|
|
return {
|
|
"nodes": [node.model_dump() for node in nodes],
|
|
}
|
|
|
|
@app.delete("/{session_uuid}")
|
|
async def delete_session(session_uuid: str):
|
|
session = sessions.pop(session_uuid, None)
|
|
if session is None:
|
|
raise HTTPException(
|
|
status_code=404, detail=f"Session {session_uuid} not found"
|
|
)
|
|
session.managed_server.reset()
|
|
return {"status": "deleted", "uuid": session_uuid}
|
|
|
|

    # -- exception handler for OpenAI-style errors --

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        return openai_error(exc.status_code, exc.detail)

    # NOTE on concurrent access: each UUID is meant to represent a single
    # rollout session used by one caller at a time. If you send concurrent
    # requests to the same UUID, the ManagedServer's node extension logic
    # may get confused because it does prefix matching on current_nodes.
    # Don't do that. UUIDs are cheap; make a new one.

    return app


# ---------------------------------------------------------------------------
# Standalone entrypoint
# ---------------------------------------------------------------------------


def main():
    import argparse
    import json

    import uvicorn
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to JSON config file with server definitions",
    )
    parser.add_argument("--port", type=int, default=9100, help="Proxy port")
    parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)

    model_name = config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))

    # Build APIServerConfigs from the JSON
    server_configs = []
    for srv in config["servers"]:
        server_configs.append(
            APIServerConfig(
                model_name=model_name,
                base_url=srv["base_url"],
                api_key=srv.get("api_key", ""),
                server_type=srv.get("server_type", "vllm"),
                num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                timeout=srv.get("timeout", 1200),
                tokenizer_name=config.get("tokenizer_name", "none"),
            )
        )

    server_manager = ServerManager(configs=server_configs)

    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name=model_name,
    )

    print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
    print(f"  Model: {model_name}")
    print(f"  Backends: {len(server_configs)} server(s)")
    for i, cfg in enumerate(server_configs):
        print(f"    [{i}] {cfg.base_url} ({cfg.server_type})")
    print()
    print("Endpoints:")
    print("  POST   /sessions/create")
    print("  POST   /{uuid}/v1/chat/completions")
    print("  POST   /{uuid}/v1/chat/completions/render")
    print("  GET    /{uuid}/nodes")
    print("  DELETE /{uuid}")
    print("  GET    /sessions")
    print("  GET    /v1/models")
    print("  GET    /health")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()