Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint

This commit is contained in:
parent 887a94374c
commit add42a2afb

11 changed files with 3370 additions and 34 deletions
623 atroposlib/envs/server_handling/managed_server_proxy.py Normal file
@@ -0,0 +1,623 @@
"""
|
||||
OpenAI-compatible chat completions proxy over ManagedServer.
|
||||
|
||||
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
|
||||
environment microservices can interact with ManagedServer via standard
|
||||
OpenAI API during multi-step rollouts.
|
||||
|
||||
Each UUID maps to a session containing a ManagedServer instance. Tool call
|
||||
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
|
||||
text — tool call translation only affects the HTTP wire format.
|
||||
|
||||
Uses ServerManager for routing across multiple backend servers (load balancing,
|
||||
health checks, etc.) — same infrastructure as the rest of atropos.
|
||||
|
||||
Usage:
|
||||
# Standalone with JSON config
|
||||
python -m atroposlib.envs.server_handling.managed_server_proxy \\
|
||||
--config servers.json \\
|
||||
--port 9100
|
||||
|
||||
# Or mount into existing FastAPI app
|
||||
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
||||
app = create_app(server_manager, tokenizer, model_name)
|
||||
|
||||
servers.json example:
|
||||
{
|
||||
"model_name": "Qwen/Qwen3-4B",
|
||||
"servers": [
|
||||
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
|
||||
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
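
# For orientation, a typical session lifecycle over HTTP (illustrative sketch:
# the uuid and values below are made up, but the routes and fields match the
# endpoints registered in create_app() below):
#
#   POST /sessions/create              {"tool_parser": "hermes"}
#     -> {"uuid": "3f2a...", "model_name": "...", "tool_parser": "hermes", ...}
#   POST /{uuid}/v1/chat/completions   {"messages": [{"role": "user", "content": "hi"}]}
#     -> standard OpenAI chat.completion JSON (see _build_openai_response)
#   DELETE /{uuid}
#     -> {"status": "deleted", "uuid": "..."}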

import logging
import time
import uuid as uuid_lib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from atroposlib.envs.server_handling.managed_server import (
    DummyManagedServer,
    ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------


class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionProxyRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 1024
    temperature: float = 1.0
    n: int = 1
    stop: Optional[List[str]] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Any] = None  # "auto", "none", "required", or dict
    # TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
    # logprobs, top_logprobs — pass through to backend when implemented


class SessionCreateRequest(BaseModel):
    tool_parser: str = "hermes"
    track_tree: bool = False
    # Pin to a specific backend server by its base_url. In production,
    # the caller gets their assigned server from the atropos API and
    # passes it here. If omitted, falls back to picking the server
    # with the most open semaphore slots (fine for dev/testing).
    base_url: Optional[str] = None


class SessionCreateResponse(BaseModel):
    uuid: str
    model_name: str
    tool_parser: str
    base_url: Optional[str] = None  # Which backend was selected
    created_at: float


class RenderResponse(BaseModel):
    prompt_text: str
    token_ids: List[int]
    num_tokens: int
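
# Example request body for /{uuid}/v1/chat/completions (illustrative sketch:
# the weather tool schema is invented for the example; the field names follow
# ChatCompletionProxyRequest above):
#
#   {
#     "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
#     "max_tokens": 256,
#     "temperature": 0.7,
#     "tools": [{"type": "function",
#                "function": {"name": "get_weather",
#                             "parameters": {"type": "object",
#                                            "properties": {"city": {"type": "string"}}}}}],
#     "tool_choice": "auto"
#   }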


# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------


@dataclass
class SessionState:
    uuid: str
    managed_server: ManagedServer
    translator: ToolCallTranslator
    model_name: str
    base_url: Optional[str] = None  # Which backend server this session is pinned to
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)


# ---------------------------------------------------------------------------
# OpenAI error format
# ---------------------------------------------------------------------------


def openai_error(
    status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type,
                "code": status_code,
            }
        },
    )


# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------


def create_app(
    server_manager: ServerManager,
    tokenizer: Any,
    model_name: str = "unknown",
) -> FastAPI:
    """Create the proxy FastAPI app.

    Args:
        server_manager: ServerManager instance managing one or more backend
            servers (VLLMServer, SGLangServer, etc.). Used to pick the most
            available server when creating sessions.
        tokenizer: HuggingFace tokenizer for the model.
        model_name: Model name to report in responses.

    Returns:
        FastAPI app with all endpoints registered.
    """

    app = FastAPI(title="ManagedServer OpenAI Proxy")
    sessions: Dict[str, SessionState] = {}

    # -- helpers --

    def _get_session(session_uuid: str) -> SessionState:
        session = sessions.get(session_uuid)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.last_accessed = time.time()
        return session

    def _get_server_base_url(server) -> Optional[str]:
        """Get the base_url from a server's config, if available."""
        if hasattr(server, "config") and hasattr(server.config, "base_url"):
            return server.config.base_url
        return None

    def _select_server(base_url: Optional[str] = None):
        """Pick a server from the manager.

        Args:
            base_url: If provided, pin to the server with this base_url.
                Raises 404 if no server matches. If None, picks the most
                available server (mirrors ServerManager.managed_server()
                logic).

        Returns:
            Selected APIServer instance.
        """
        if base_url is not None:
            # Pin to specific server by base_url
            for server in server_manager.servers:
                server_url = _get_server_base_url(server)
                if server_url and server_url.rstrip("/") == base_url.rstrip("/"):
                    return server
            # No match — list available URLs in error
            available = [
                _get_server_base_url(s)
                for s in server_manager.servers
                if _get_server_base_url(s)
            ]
            raise HTTPException(
                status_code=404,
                detail=f"No server with base_url '{base_url}'. Available: {available}",
            )

        # Auto-select most available
        most_available_idx = 0
        most_available_slots = -1
        for i, server in enumerate(server_manager.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_slots:
                most_available_idx = i
                most_available_slots = server.sem._value
        return server_manager.servers[most_available_idx]
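
    # Heuristic note: assuming server.sem is an asyncio.Semaphore, its private
    # _value attribute holds the current number of free slots. Reading it is a
    # best-effort load signal rather than a reservation, so two concurrent
    # create_session calls may pick the same backend.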

    def _render_prompt(
        messages: List[Dict[str, Any]],
        tools: Optional[List[dict]] = None,
        translator: Optional[ToolCallTranslator] = None,
    ) -> str:
        """Render messages to prompt text via chat template.

        If a translator is provided, converts OpenAI tool_call messages
        back to raw text first.
        """
        # Convert messages to dicts
        msg_dicts = []
        for msg in messages:
            if isinstance(msg, BaseModel):
                msg_dicts.append(msg.model_dump(exclude_none=True))
            else:
                msg_dicts.append(msg)

        # Reconstruct raw text for assistant tool_call messages
        if translator:
            msg_dicts = translator.convert_messages_for_template(msg_dicts)

        # Build kwargs for apply_chat_template
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools

        return tokenizer.apply_chat_template(msg_dicts, **template_kwargs)
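
    # For illustration, assuming the hermes-style format used by vLLM's
    # parser of the same name: an assistant message carrying
    #   {"tool_calls": [{"function": {"name": "get_weather",
    #                                 "arguments": "{\"city\": \"Paris\"}"}}]}
    # would be reconstructed by the translator into raw text roughly like
    #   <tool_call>
    #   {"name": "get_weather", "arguments": {"city": "Paris"}}
    #   </tool_call>
    # before the chat template is applied, so the rendered prompt matches
    # what the model originally generated.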

    def _build_openai_response(
        choices_data: List[dict],
        model: str,
    ) -> dict:
        """Build an OpenAI ChatCompletion response dict."""
        return {
            "id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": choices_data,
        }

    # -- endpoints --

    @app.get("/health")
    async def health():
        healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
        return {
            "status": "ok",
            "model": model_name,
            "sessions": len(sessions),
            "servers": len(server_manager.servers),
            "healthy_servers": healthy_servers,
        }

    @app.get("/v1/models")
    async def list_models():
        return {
            "object": "list",
            "data": [
                {
                    "id": model_name,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "atropos",
                }
            ],
        }

    @app.post("/setup")
    async def setup(request: Request):
        """Receive server configuration from ServerManager.

        Accepts the same JSON format as the standalone --config file.
        Replaces the current server_manager's servers with the new config.
        Called by ServerManager at startup to push its config to the proxy.

        Body: {"model_name": "...",
               "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
        """
        config = await request.json()

        new_configs = []
        new_model = config.get("model_name", model_name)
        for srv in config.get("servers", []):
            new_configs.append(
                APIServerConfig(
                    model_name=new_model,
                    base_url=srv["base_url"],
                    api_key=srv.get("api_key", ""),
                    server_type=srv.get("server_type", "vllm"),
                    num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                    num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                    timeout=srv.get("timeout", 1200),
                    tokenizer_name=config.get("tokenizer_name", "none"),
                )
            )

        if new_configs:
            new_manager = ServerManager(configs=new_configs)
            server_manager.servers = new_manager.servers
            logger.info("Setup: replaced servers with %d new configs", len(new_configs))

        return {
            "status": "ok",
            "servers": len(server_manager.servers),
            "model_name": new_model,
        }

    @app.get("/servers")
    async def list_servers():
        """List available backend servers.

        Useful for discovery/debugging. In production, server allocation
        is managed by the atropos API — the environment gets told which
        server to use and passes that base_url to POST /sessions/create.
        """
        server_list = []
        for i, server in enumerate(server_manager.servers):
            url = _get_server_base_url(server)
            healthy = getattr(server, "server_healthy", True)
            server_list.append(
                {
                    "index": i,
                    "base_url": url,
                    "healthy": healthy,
                    "model_name": (
                        getattr(server.config, "model_name", model_name)
                        if hasattr(server, "config")
                        else model_name
                    ),
                    "server_type": (
                        getattr(server.config, "server_type", "unknown")
                        if hasattr(server, "config")
                        else "unknown"
                    ),
                }
            )
        return {"servers": server_list}

    @app.post("/sessions/create", response_model=SessionCreateResponse)
    async def create_session(request: SessionCreateRequest):
        session_uuid = str(uuid_lib.uuid4())

        # Pick server — pinned to base_url if specified, otherwise most available
        selected_server = _select_server(base_url=request.base_url)
        selected_url = _get_server_base_url(selected_server)

        # Use DummyManagedServer for OpenAI endpoints (no logprobs support)
        if isinstance(selected_server, OpenAIServer):
            logger.info(
                "Session %s using DummyManagedServer (OpenAI endpoint). "
                "Token IDs and logprobs will be placeholders.",
                session_uuid,
            )
            managed = DummyManagedServer(
                server=selected_server,
                tokenizer=tokenizer,
                track_tree=request.track_tree,
            )
        else:
            managed = ManagedServer(
                server=selected_server,
                tokenizer=tokenizer,
                track_tree=request.track_tree,
                tool_parser=request.tool_parser,
            )

        # Translator kept for the render endpoint (prompt preview)
        translator = ToolCallTranslator(
            tokenizer=tokenizer,
            parser_name=request.tool_parser,
        )
        session = SessionState(
            uuid=session_uuid,
            managed_server=managed,
            translator=translator,
            model_name=model_name,
            base_url=selected_url,
        )
        sessions[session_uuid] = session

        return SessionCreateResponse(
            uuid=session_uuid,
            model_name=model_name,
            tool_parser=request.tool_parser,
            base_url=selected_url,
            created_at=session.created_at,
        )

    @app.get("/sessions")
    async def list_sessions():
        return {
            "sessions": [
                {
                    "uuid": s.uuid,
                    "model_name": s.model_name,
                    "base_url": s.base_url,
                    "created_at": s.created_at,
                    "last_accessed": s.last_accessed,
                    "num_nodes": len(
                        s.managed_server.current_nodes
                        if hasattr(s.managed_server, "current_nodes")
                        else s.managed_server.sequences
                    ),
                }
                for s in sessions.values()
            ]
        }

    @app.post("/{session_uuid}/v1/chat/completions")
    async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
        session = _get_session(session_uuid)
        managed = session.managed_server

        if not request.messages:
            return openai_error(400, "messages must not be empty")

        # Convert pydantic messages to dicts
        messages = [msg.model_dump(exclude_none=True) for msg in request.messages]

        # Build kwargs — ManagedServer.chat_completion() handles all tool
        # call logic internally (template rendering, inbound reconstruction,
        # outbound parsing, skip_special_tokens)
        completion_kwargs = {
            "messages": messages,
            "n": request.n,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
        }
        if request.stop:
            completion_kwargs["stop"] = request.stop
        if request.tools:
            completion_kwargs["tools"] = request.tools
        if request.tool_choice is not None:
            completion_kwargs["tool_choice"] = request.tool_choice

        try:
            result = await managed.chat_completion(**completion_kwargs)
        except Exception as e:
            logger.exception("Completion failed")
            return openai_error(
                500, f"Completion failed: {e}", error_type="server_error"
            )

        # Convert ChatCompletion to a JSON-serializable response
        choices = []
        for choice in result.choices:
            choice_data = {
                "index": choice.index,
                "message": {
                    "role": choice.message.role,
                    "content": choice.message.content,
                },
                "finish_reason": choice.finish_reason,
            }
            if choice.message.tool_calls:
                choice_data["message"]["tool_calls"] = choice.message.tool_calls

            choices.append(choice_data)

        return _build_openai_response(choices, model_name)
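
    # Example wire response containing a parsed tool call (illustrative
    # sketch: the id and argument values are invented; the envelope comes
    # from _build_openai_response and the choice shape from the loop above):
    #
    #   {"id": "chatcmpl-ab12cd34ef56", "object": "chat.completion", ...,
    #    "choices": [{"index": 0,
    #                 "message": {"role": "assistant", "content": null,
    #                             "tool_calls": [{"id": "call_0",
    #                                             "type": "function",
    #                                             "function": {"name": "get_weather",
    #                                                          "arguments": "{\"city\": \"Paris\"}"}}]},
    #                 "finish_reason": "tool_calls"}]}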

    @app.post("/{session_uuid}/v1/chat/completions/render")
    async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
        session = _get_session(session_uuid)

        try:
            prompt_text = _render_prompt(
                messages=request.messages,
                tools=request.tools,
                translator=session.translator,
            )
        except Exception as e:
            return openai_error(400, f"Failed to render prompt: {e}")

        token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)

        return RenderResponse(
            prompt_text=prompt_text,
            token_ids=token_ids,
            num_tokens=len(token_ids),
        )

    @app.get("/{session_uuid}/nodes")
    async def get_nodes(session_uuid: str):
        session = _get_session(session_uuid)
        state = session.managed_server.get_state()

        if session.managed_server.track_tree:
            nodes = list(state.get("sequences", {}).values())
        else:
            nodes = state.get("nodes", [])

        return {
            "nodes": [node.model_dump() for node in nodes],
        }

    @app.delete("/{session_uuid}")
    async def delete_session(session_uuid: str):
        session = sessions.pop(session_uuid, None)
        if session is None:
            raise HTTPException(
                status_code=404, detail=f"Session {session_uuid} not found"
            )
        session.managed_server.reset()
        return {"status": "deleted", "uuid": session_uuid}

    # -- exception handler for OpenAI-style errors --

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        return openai_error(exc.status_code, exc.detail)

    # NOTE on concurrent access: each UUID is meant to represent a single
    # rollout session used by one caller at a time. If you send concurrent
    # requests to the same UUID, the ManagedServer's node extension logic
    # may get confused because it does prefix matching on current_nodes.
    # Don't do that. UUIDs are cheap; make a new one.

    return app


# ---------------------------------------------------------------------------
# Standalone entrypoint
# ---------------------------------------------------------------------------


def main():
    import argparse
    import json

    import uvicorn
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to JSON config file with server definitions",
    )
    parser.add_argument("--port", type=int, default=9100, help="Proxy port")
    parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)

    model_name = config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))

    # Build APIServerConfigs from the JSON
    server_configs = []
    for srv in config["servers"]:
        server_configs.append(
            APIServerConfig(
                model_name=model_name,
                base_url=srv["base_url"],
                api_key=srv.get("api_key", ""),
                server_type=srv.get("server_type", "vllm"),
                num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                timeout=srv.get("timeout", 1200),
                tokenizer_name=config.get("tokenizer_name", "none"),
            )
        )

    server_manager = ServerManager(configs=server_configs)

    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name=model_name,
    )

    print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
    print(f"  Model: {model_name}")
    print(f"  Backends: {len(server_configs)} server(s)")
    for i, cfg in enumerate(server_configs):
        print(f"    [{i}] {cfg.base_url} ({cfg.server_type})")
    print()
    print("Endpoints:")
    print("  POST   /sessions/create")
    print("  POST   /{uuid}/v1/chat/completions")
    print("  POST   /{uuid}/v1/chat/completions/render")
    print("  GET    /{uuid}/nodes")
    print("  DELETE /{uuid}")
    print("  GET    /sessions")
    print("  GET    /v1/models")
    print("  GET    /health")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()