add tool call parsing based on vllm impl and an openai server endpoint

dmahan93 2026-03-02 23:17:13 -06:00
parent 887a94374c
commit add42a2afb
11 changed files with 3370 additions and 34 deletions

atroposlib/envs/server_handling/managed_server.py (View file)

@ -62,6 +62,7 @@ class ManagedServer:
server: APIServer,
tokenizer: Optional[Any] = None,
track_tree: bool = False,
tool_parser: Optional[str] = None,
):
"""
Initialize the managed server.
@ -73,10 +74,17 @@ class ManagedServer:
track_tree: If True, maintains a tree structure with parent-child links
(for multi-turn RL with per-step advantages). If False (default),
maintains a simple list of current nodes that updates in-place.
tool_parser: Optional vLLM tool parser name (e.g. "hermes", "llama3_json",
"mistral", etc.). If provided, enables tool call support in
chat_completion(). The parser handles extraction of structured
tool calls from raw model output. See
ToolParserManager.list_registered() for available parsers.
"""
self.server = server
self.tokenizer = tokenizer
self.track_tree = track_tree
self._tool_parser_name = tool_parser
self._translator = None # Lazy init
# Initialize storage based on mode
if track_tree:
@ -107,19 +115,57 @@ class ManagedServer:
)
self.tokenizer = None
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
def _get_translator(self):
"""Lazily create the ToolCallTranslator when first needed.
Returns None if tool_parser was not specified or if vLLM is not
installed (the translator will warn on creation in that case).
"""
if self._translator is None and self._tool_parser_name and self.tokenizer:
try:
from atroposlib.envs.server_handling.tool_call_translator import (
ToolCallTranslator,
)
self._translator = ToolCallTranslator(
tokenizer=self.tokenizer,
parser_name=self._tool_parser_name,
)
except Exception as e:
warnings.warn(
f"Failed to create ToolCallTranslator: {e}. "
"Tool call parsing will be disabled.",
stacklevel=2,
)
self._tool_parser_name = None # Don't retry
return None
return self._translator
def _convert_messages_to_prompt(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None,
) -> str:
"""
Convert chat messages to prompt text using tokenizer's chat template.
Args:
messages: List of message dicts with 'role' and 'content'
tools: Optional list of tool definitions (OpenAI format). Passed to
apply_chat_template() so the template can inject tool defs
into the system prompt.
Returns:
Formatted prompt string
"""
# If tools are active and we have a translator, convert any assistant
# messages with tool_calls back to raw text first
if tools and self._get_translator():
messages = self._get_translator().convert_messages_for_template(messages)
if self.tokenizer is None:
# Fallback: simple concatenation
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
if hasattr(self.tokenizer, "apply_chat_template"):
# Only add generation prompt if last message is not from assistant
@ -127,13 +173,19 @@ class ManagedServer:
len(messages) == 0 or messages[-1].get("role") != "assistant"
)
# Build kwargs
template_kwargs = {
"tokenize": False,
"add_generation_prompt": add_generation_prompt,
}
if tools:
template_kwargs["tools"] = tools
# Use the tokenizer's chat template
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
else:
# Fallback for tokenizers without chat template
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
def _debug_requests_enabled(self) -> bool:
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
@ -268,15 +320,35 @@ class ManagedServer:
Internally converts to prompt, calls tokens_and_logprobs_completion,
tracks the sequence, and reconstructs a ChatCompletion response.
Supports tool calling when a tool_parser was provided at init:
- Accepts `tools` and `tool_choice` kwargs
- Converts inbound assistant tool_call messages to raw text
- Parses outbound model output for tool calls
- Returns ChatCompletion with proper tool_calls in choices
- Preserves raw text in tracked nodes (tool parsing is response-only)
Args:
**kwargs: Standard chat completion kwargs (messages, n, etc.)
**kwargs: Standard chat completion kwargs (messages, n, max_tokens,
temperature, tools, tool_choice, etc.)
Returns:
ChatCompletion response
ChatCompletion response (with tool_calls if detected)
"""
# Get input text
# Extract tool-related kwargs
tools = kwargs.pop("tools", None)
tool_choice = kwargs.pop("tool_choice", None)
has_tools = bool(tools) and self._get_translator() is not None
# Default tool_choice to "auto" if tools provided
if has_tools and tool_choice is None:
tool_choice = "auto"
# Get input text — passes tools for template rendering and
# handles reconstruction of inbound tool_call messages
messages = kwargs.get("messages", [])
prompt = self._convert_messages_to_prompt(messages)
prompt = self._convert_messages_to_prompt(
messages, tools=tools if has_tools else None
)
# Handle parent node and extending logic based on mode
if self.track_tree:
@ -296,11 +368,12 @@ class ManagedServer:
msg_count = len(messages)
prompt_preview = prompt.replace("\n", "\\n")[:600]
logger.debug(
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s tools=%s",
msg_count,
completion_kwargs.get("n"),
completion_kwargs.get("max_tokens"),
completion_kwargs.get("temperature"),
bool(tools),
)
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
@ -336,15 +409,18 @@ class ManagedServer:
else:
finish_reason = finish_reason_raw
# Decode completion text
# Decode completion text — use skip_special_tokens=False when
# tools are active so <tool_call> tags aren't stripped
if self.tokenizer is not None:
completion_text = self.tokenizer.decode(
output_tokens, skip_special_tokens=True
output_tokens,
skip_special_tokens=not has_tools,
)
else:
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
# Create and store sequence node
# Create and store sequence node — always uses the raw text,
# tool parsing only affects the ChatCompletion response
node = self._create_sequence_node(
input_text=prompt,
parent_node=parent_node,
@ -373,14 +449,50 @@ class ManagedServer:
# New context - append to list
self.current_nodes.append(node)
# Parse tool calls from raw output if tools are active
tool_calls_parsed = None
content_for_response = completion_text
if has_tools and tool_choice != "none":
translator = self._get_translator()
content_for_response, tool_calls_parsed, finish_reason = (
translator.parse_model_output(
raw_text=completion_text,
tool_choice=(
tool_choice if isinstance(tool_choice, str) else "auto"
),
tools=tools,
)
)
# Build choice
message_kwargs = {
"content": content_for_response,
"role": "assistant",
}
# Note: openai's ChatCompletionMessage model handles tool_calls
# but we can't pass them through the constructor easily. We'll
# attach them after construction if needed.
choice = Choice(
finish_reason=finish_reason,
index=i,
message=ChatCompletionMessage(
content=completion_text, role="assistant"
),
message=ChatCompletionMessage(**message_kwargs),
)
# Attach tool_calls to the message if present
if tool_calls_parsed:
choice.message.tool_calls = [
# Convert vLLM ToolCall to openai ToolCall format
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls_parsed
]
choices.append(choice)
# Construct ChatCompletion response
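A minimal usage sketch of the new tool-calling path (not part of the diff): server is any vLLM/SGLang-backed APIServer, and the model name and get_weather tool are placeholders.

    from transformers import AutoTokenizer
    from atroposlib.envs.server_handling.managed_server import ManagedServer

    async def rollout(server):
        tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
        managed = ManagedServer(server=server, tokenizer=tok, tool_parser="hermes")
        tools = [{
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }]
        resp = await managed.chat_completion(
            messages=[{"role": "user", "content": "Weather in Paris?"}],
            tools=tools, tool_choice="auto", n=1, max_tokens=256,
        )
        msg = resp.choices[0].message
        print(msg.tool_calls or msg.content)
        # The tracked nodes still hold the raw text, <tool_call> tags included.
        print(managed.get_state()["nodes"])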

atroposlib/envs/server_handling/managed_server_proxy.py (View file)

@ -0,0 +1,623 @@
"""
OpenAI-compatible chat completions proxy over ManagedServer.
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
environment microservices can interact with ManagedServer via standard
OpenAI API during multi-step rollouts.
Each UUID maps to a session containing a ManagedServer instance. Tool call
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
text; tool call translation only affects the HTTP wire format.
Uses ServerManager for routing across multiple backend servers (load balancing,
health checks, etc.), the same infrastructure as the rest of atropos.
Usage:
# Standalone with JSON config
python -m atroposlib.envs.server_handling.managed_server_proxy \\
--config servers.json \\
--port 9100
# Or mount into existing FastAPI app
from atroposlib.envs.server_handling.managed_server_proxy import create_app
app = create_app(server_manager, tokenizer, model_name)
servers.json example:
{
"model_name": "Qwen/Qwen3-4B",
"servers": [
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
]
}
"""
import logging
import time
import uuid as uuid_lib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from atroposlib.envs.server_handling.managed_server import (
DummyManagedServer,
ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class ChatMessage(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[dict]] = None
tool_call_id: Optional[str] = None
name: Optional[str] = None
class ChatCompletionProxyRequest(BaseModel):
messages: List[ChatMessage]
max_tokens: int = 1024
temperature: float = 1.0
n: int = 1
stop: Optional[List[str]] = None
tools: Optional[List[dict]] = None
tool_choice: Optional[Any] = None # "auto", "none", "required", or dict
# TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
# logprobs, top_logprobs — pass through to backend when implemented
class SessionCreateRequest(BaseModel):
tool_parser: str = "hermes"
track_tree: bool = False
# Pin to a specific backend server by its base_url. In production,
# the caller gets their assigned server from the atropos API and
# passes it here. If omitted, falls back to picking the server
# with the most open semaphore slots (fine for dev/testing).
base_url: Optional[str] = None
class SessionCreateResponse(BaseModel):
uuid: str
model_name: str
tool_parser: str
base_url: Optional[str] = None # Which backend was selected
created_at: float
class RenderResponse(BaseModel):
prompt_text: str
token_ids: List[int]
num_tokens: int
# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------
@dataclass
class SessionState:
uuid: str
managed_server: ManagedServer
translator: ToolCallTranslator
model_name: str
base_url: Optional[str] = None # Which backend server this session is pinned to
created_at: float = field(default_factory=time.time)
last_accessed: float = field(default_factory=time.time)
# ---------------------------------------------------------------------------
# OpenAI error format
# ---------------------------------------------------------------------------
def openai_error(
status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
return JSONResponse(
status_code=status_code,
content={
"error": {
"message": message,
"type": error_type,
"code": status_code,
}
},
)
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(
server_manager: ServerManager,
tokenizer: Any,
model_name: str = "unknown",
) -> FastAPI:
"""Create the proxy FastAPI app.
Args:
server_manager: ServerManager instance managing one or more backend
servers (VLLMServer, SGLangServer, etc.). Used to pick the most
available server when creating sessions.
tokenizer: HuggingFace tokenizer for the model.
model_name: Model name to report in responses.
Returns:
FastAPI app with all endpoints registered.
"""
app = FastAPI(title="ManagedServer OpenAI Proxy")
sessions: Dict[str, SessionState] = {}
# -- helpers --
def _get_session(session_uuid: str) -> SessionState:
session = sessions.get(session_uuid)
if session is None:
raise HTTPException(
status_code=404, detail=f"Session {session_uuid} not found"
)
session.last_accessed = time.time()
return session
def _get_server_base_url(server) -> Optional[str]:
"""Get the base_url from a server's config, if available."""
if hasattr(server, "config") and hasattr(server.config, "base_url"):
return server.config.base_url
return None
def _select_server(base_url: Optional[str] = None):
"""Pick a server from the manager.
Args:
base_url: If provided, pin to the server with this base_url.
Raises 404 if no server matches.
If None, picks the most available server (mirrors
ServerManager.managed_server() logic).
Returns:
Selected APIServer instance.
"""
if base_url is not None:
# Pin to specific server by base_url
for server in server_manager.servers:
server_url = _get_server_base_url(server)
if server_url and server_url.rstrip("/") == base_url.rstrip("/"):
return server
# No match — list available URLs in error
available = [
_get_server_base_url(s)
for s in server_manager.servers
if _get_server_base_url(s)
]
raise HTTPException(
status_code=404,
detail=f"No server with base_url '{base_url}'. Available: {available}",
)
# Auto-select most available
most_available_idx = 0
most_available_slots = -1
for i, server in enumerate(server_manager.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_slots:
most_available_idx = i
most_available_slots = server.sem._value
return server_manager.servers[most_available_idx]
def _render_prompt(
messages: List[Dict[str, Any]],
tools: Optional[List[dict]] = None,
translator: Optional[ToolCallTranslator] = None,
) -> str:
"""Render messages to prompt text via chat template.
If a translator is provided, converts OpenAI tool_call messages
back to raw text first.
"""
# Convert messages to dicts
msg_dicts = []
for msg in messages:
if isinstance(msg, BaseModel):
msg_dicts.append(msg.model_dump(exclude_none=True))
else:
msg_dicts.append(msg)
# Reconstruct raw text for assistant tool_call messages
if translator:
msg_dicts = translator.convert_messages_for_template(msg_dicts)
# Build kwargs for apply_chat_template
template_kwargs = {
"tokenize": False,
"add_generation_prompt": True,
}
if tools:
template_kwargs["tools"] = tools
return tokenizer.apply_chat_template(msg_dicts, **template_kwargs)
def _build_openai_response(
choices_data: List[dict],
model: str,
) -> dict:
"""Build an OpenAI ChatCompletion response dict."""
return {
"id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": choices_data,
}
# -- endpoints --
@app.get("/health")
async def health():
healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
return {
"status": "ok",
"model": model_name,
"sessions": len(sessions),
"servers": len(server_manager.servers),
"healthy_servers": healthy_servers,
}
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [
{
"id": model_name,
"object": "model",
"created": int(time.time()),
"owned_by": "atropos",
}
],
}
@app.post("/setup")
async def setup(request: Request):
"""Receive server configuration from ServerManager.
Accepts the same JSON format as the standalone --config file.
Replaces the current server_manager's servers with the new config.
Called by ServerManager at startup to push its config to the proxy.
Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
"""
config = await request.json()
new_configs = []
new_model = config.get("model_name", model_name)
for srv in config.get("servers", []):
new_configs.append(
APIServerConfig(
model_name=new_model,
base_url=srv["base_url"],
api_key=srv.get("api_key", ""),
server_type=srv.get("server_type", "vllm"),
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
timeout=srv.get("timeout", 1200),
tokenizer_name=config.get("tokenizer_name", "none"),
)
)
if new_configs:
new_manager = ServerManager(configs=new_configs)
server_manager.servers = new_manager.servers
logger.info("Setup: replaced servers with %d new configs", len(new_configs))
return {
"status": "ok",
"servers": len(server_manager.servers),
"model_name": new_model,
}
@app.get("/servers")
async def list_servers():
"""List available backend servers.
Useful for discovery/debugging. In production, server allocation
is managed by the atropos API; the environment gets told which
server to use and passes that base_url to POST /sessions/create.
"""
server_list = []
for i, server in enumerate(server_manager.servers):
url = _get_server_base_url(server)
healthy = getattr(server, "server_healthy", True)
server_list.append(
{
"index": i,
"base_url": url,
"healthy": healthy,
"model_name": (
getattr(server.config, "model_name", model_name)
if hasattr(server, "config")
else model_name
),
"server_type": (
getattr(server.config, "server_type", "unknown")
if hasattr(server, "config")
else "unknown"
),
}
)
return {"servers": server_list}
@app.post("/sessions/create", response_model=SessionCreateResponse)
async def create_session(request: SessionCreateRequest):
session_uuid = str(uuid_lib.uuid4())
# Pick server — pinned to base_url if specified, otherwise most available
selected_server = _select_server(base_url=request.base_url)
selected_url = _get_server_base_url(selected_server)
# Use DummyManagedServer for OpenAI endpoints (no logprobs support)
if isinstance(selected_server, OpenAIServer):
logger.info(
"Session %s using DummyManagedServer (OpenAI endpoint). "
"Token IDs and logprobs will be placeholders.",
session_uuid,
)
managed = DummyManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
)
else:
managed = ManagedServer(
server=selected_server,
tokenizer=tokenizer,
track_tree=request.track_tree,
tool_parser=request.tool_parser,
)
# Translator kept for the render endpoint (prompt preview)
translator = ToolCallTranslator(
tokenizer=tokenizer,
parser_name=request.tool_parser,
)
session = SessionState(
uuid=session_uuid,
managed_server=managed,
translator=translator,
model_name=model_name,
base_url=selected_url,
)
sessions[session_uuid] = session
return SessionCreateResponse(
uuid=session_uuid,
model_name=model_name,
tool_parser=request.tool_parser,
base_url=selected_url,
created_at=session.created_at,
)
@app.get("/sessions")
async def list_sessions():
return {
"sessions": [
{
"uuid": s.uuid,
"model_name": s.model_name,
"base_url": s.base_url,
"created_at": s.created_at,
"last_accessed": s.last_accessed,
"num_nodes": len(
s.managed_server.current_nodes
if hasattr(s.managed_server, "current_nodes")
else s.managed_server.sequences
),
}
for s in sessions.values()
]
}
@app.post("/{session_uuid}/v1/chat/completions")
async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
session = _get_session(session_uuid)
managed = session.managed_server
if not request.messages:
return openai_error(400, "messages must not be empty")
# Convert pydantic messages to dicts
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
# Build kwargs — ManagedServer.chat_completion() handles all tool
# call logic internally (template rendering, inbound reconstruction,
# outbound parsing, skip_special_tokens)
completion_kwargs = {
"messages": messages,
"n": request.n,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
}
if request.stop:
completion_kwargs["stop"] = request.stop
if request.tools:
completion_kwargs["tools"] = request.tools
if request.tool_choice is not None:
completion_kwargs["tool_choice"] = request.tool_choice
try:
result = await managed.chat_completion(**completion_kwargs)
except Exception as e:
logger.exception("Completion failed")
return openai_error(
500, f"Completion failed: {e}", error_type="server_error"
)
# Convert ChatCompletion to JSON-serializable response
choices = []
for choice in result.choices:
choice_data = {
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content,
},
"finish_reason": choice.finish_reason,
}
if choice.message.tool_calls:
choice_data["message"]["tool_calls"] = choice.message.tool_calls
choices.append(choice_data)
return _build_openai_response(choices, model_name)
@app.post("/{session_uuid}/v1/chat/completions/render")
async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
session = _get_session(session_uuid)
try:
prompt_text = _render_prompt(
messages=request.messages,
tools=request.tools,
translator=session.translator,
)
except Exception as e:
return openai_error(400, f"Failed to render prompt: {e}")
token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
return RenderResponse(
prompt_text=prompt_text,
token_ids=token_ids,
num_tokens=len(token_ids),
)
@app.get("/{session_uuid}/nodes")
async def get_nodes(session_uuid: str):
session = _get_session(session_uuid)
state = session.managed_server.get_state()
if session.managed_server.track_tree:
nodes = list(state.get("sequences", {}).values())
else:
nodes = state.get("nodes", [])
return {
"nodes": [node.model_dump() for node in nodes],
}
@app.delete("/{session_uuid}")
async def delete_session(session_uuid: str):
session = sessions.pop(session_uuid, None)
if session is None:
raise HTTPException(
status_code=404, detail=f"Session {session_uuid} not found"
)
session.managed_server.reset()
return {"status": "deleted", "uuid": session_uuid}
# -- exception handler for OpenAI-style errors --
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return openai_error(exc.status_code, exc.detail)
# NOTE on concurrent access: each UUID is meant to represent a single
# rollout session used by one caller at a time. If you send concurrent
# requests to the same UUID, the ManagedServer's node extension logic
# may get confused because it does prefix matching on current_nodes.
# Don't do that. UUIDs are cheap, make a new one.
return app
# ---------------------------------------------------------------------------
# Standalone entrypoint
# ---------------------------------------------------------------------------
def main():
import argparse
import json
import uvicorn
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
parser.add_argument(
"--config",
required=True,
help="Path to JSON config file with server definitions",
)
parser.add_argument("--port", type=int, default=9100, help="Proxy port")
parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
args = parser.parse_args()
# Load config
with open(args.config) as f:
config = json.load(f)
model_name = config["model_name"]
tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))
# Build APIServerConfigs from the JSON
server_configs = []
for srv in config["servers"]:
server_configs.append(
APIServerConfig(
model_name=model_name,
base_url=srv["base_url"],
api_key=srv.get("api_key", ""),
server_type=srv.get("server_type", "vllm"),
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
timeout=srv.get("timeout", 1200),
tokenizer_name=config.get("tokenizer_name", "none"),
)
)
server_manager = ServerManager(configs=server_configs)
app = create_app(
server_manager=server_manager,
tokenizer=tokenizer,
model_name=model_name,
)
print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
print(f" Model: {model_name}")
print(f" Backends: {len(server_configs)} server(s)")
for i, cfg in enumerate(server_configs):
print(f" [{i}] {cfg.base_url} ({cfg.server_type})")
print()
print("Endpoints:")
print(" POST /sessions/create")
print(" POST /{uuid}/v1/chat/completions")
print(" POST /{uuid}/v1/chat/completions/render")
print(" GET /{uuid}/nodes")
print(" DELETE /{uuid}")
print(" GET /sessions")
print(" GET /v1/models")
print(" GET /health")
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
if __name__ == "__main__":
main()
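For reference, a sketch of the wire contract from an external environment process using the standard OpenAI Python client; the host/port, session settings, and tool schema are placeholders, not values from this commit.

    import requests
    from openai import OpenAI

    PROXY = "http://localhost:9100"
    session = requests.post(
        f"{PROXY}/sessions/create", json={"tool_parser": "hermes"}, timeout=30
    ).json()

    client = OpenAI(base_url=f"{PROXY}/{session['uuid']}/v1", api_key="unused")
    resp = client.chat.completions.create(
        model=session["model_name"],
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=[{
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }],
    )
    print(resp.choices[0].message.tool_calls)

    # Token/logprob data accumulated for training sits behind the session UUID...
    nodes = requests.get(f"{PROXY}/{session['uuid']}/nodes", timeout=30).json()["nodes"]
    # ...and the session is deleted when the rollout ends.
    requests.delete(f"{PROXY}/{session['uuid']}", timeout=30)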

atroposlib/envs/server_handling/proxy_client.py (View file)

@ -0,0 +1,292 @@
"""
Client that talks to the ManagedServer OpenAI proxy over HTTP.
Implements the same interface as ManagedServer so it can be used as a
drop-in replacement when ServerManager is constructed with use_proxy=True
(or a proxy_url / ATROPOS_PROXY_URL).
The proxy handles all the token tracking, tool call parsing, and sequence
management. This client just ferries requests/responses over HTTP and
reconstructs the SequenceNode objects from the JSON.
"""
import logging
import time
import uuid as uuid_lib
from typing import Any, Dict, List, Optional
import aiohttp
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.completion import Completion # noqa: F401 — used in type hint
from atroposlib.envs.server_handling.managed_server import SequenceNode
logger = logging.getLogger(__name__)
class ProxyManagedServer:
"""Client that talks to the ManagedServer OpenAI proxy.
Same interface as ManagedServer chat_completion(), completion(),
get_state(), reset(). But instead of doing token tracking in-process,
delegates everything to the proxy over HTTP.
Created by ServerManager.managed_server() when the manager was initialized
with use_proxy=True or a proxy_url.
Example:
# server_manager constructed with use_proxy=True or proxy_url
async with server_manager.managed_server() as managed:
# Same API as regular ManagedServer
resp = await managed.chat_completion(
messages=[{"role": "user", "content": "Hello"}],
n=4, max_tokens=100, temperature=1.0,
)
state = managed.get_state()
nodes = state["nodes"]
# Extra: get URL for external apps to use directly
url = managed.get_url()
# → "http://proxy:9100/{uuid}/v1"
"""
def __init__(
self,
proxy_url: str,
session_uuid: str,
model_name: str = "unknown",
base_url: Optional[str] = None,
):
"""
Args:
proxy_url: Base URL of the proxy (e.g. "http://localhost:9100")
session_uuid: UUID of the session on the proxy.
model_name: Model name (for response objects).
base_url: The backend server this session is pinned to.
"""
self.proxy_url = proxy_url.rstrip("/")
self.session_uuid = session_uuid
self.model_name = model_name
self.base_url = base_url
# Cache for nodes (populated by get_state)
self._cached_nodes: Optional[List[SequenceNode]] = None
def get_url(self) -> str:
"""Get the OpenAI-compatible API URL for this session.
External apps can use this URL with any OpenAI client:
client = openai.OpenAI(base_url=managed.get_url())
client.chat.completions.create(messages=..., tools=...)
Returns:
URL like "http://proxy:9100/{uuid}/v1"
"""
return f"{self.proxy_url}/{self.session_uuid}/v1"
async def _post(self, path: str, json: dict, timeout: int = 300) -> dict:
"""Make a POST request to the proxy."""
url = f"{self.proxy_url}{path}"
async with aiohttp.ClientSession() as session:
async with session.post(
url, json=json, timeout=aiohttp.ClientTimeout(total=timeout)
) as resp:
data = await resp.json()
if resp.status != 200:
error_msg = data.get("error", {}).get("message", str(data))
raise RuntimeError(
f"Proxy request failed ({resp.status}): {error_msg}"
)
return data
async def _get(self, path: str, timeout: int = 30) -> dict:
"""Make a GET request to the proxy."""
url = f"{self.proxy_url}{path}"
async with aiohttp.ClientSession() as session:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=timeout)
) as resp:
data = await resp.json()
if resp.status != 200:
error_msg = data.get("error", {}).get("message", str(data))
raise RuntimeError(
f"Proxy request failed ({resp.status}): {error_msg}"
)
return data
async def _delete(self, path: str, timeout: int = 30) -> dict:
"""Make a DELETE request to the proxy."""
url = f"{self.proxy_url}{path}"
async with aiohttp.ClientSession() as session:
async with session.delete(
url, timeout=aiohttp.ClientTimeout(total=timeout)
) as resp:
data = await resp.json()
return data
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""Send a chat completion request through the proxy.
Same interface as ManagedServer.chat_completion().
The proxy handles template rendering, tool call parsing,
and token/logprob tracking.
"""
# Convert messages to serializable format
messages = kwargs.get("messages", [])
serialized_messages = []
for msg in messages:
if isinstance(msg, dict):
serialized_messages.append(msg)
else:
serialized_messages.append(dict(msg))
body = {
"messages": serialized_messages,
"max_tokens": kwargs.get("max_tokens", 1024),
"temperature": kwargs.get("temperature", 1.0),
"n": kwargs.get("n", 1),
}
if kwargs.get("stop"):
body["stop"] = kwargs["stop"]
if kwargs.get("tools"):
body["tools"] = kwargs["tools"]
if kwargs.get("tool_choice") is not None:
body["tool_choice"] = kwargs["tool_choice"]
data = await self._post(f"/{self.session_uuid}/v1/chat/completions", json=body)
# Reconstruct ChatCompletion from proxy response
choices = []
for choice_data in data.get("choices", []):
msg = choice_data.get("message", {})
choice = Choice(
finish_reason=choice_data.get("finish_reason", "stop"),
index=choice_data.get("index", 0),
message=ChatCompletionMessage(
content=msg.get("content"),
role=msg.get("role", "assistant"),
),
)
choices.append(choice)
return ChatCompletion(
id=data.get("id", str(uuid_lib.uuid4())),
created=data.get("created", int(time.time())),
model=data.get("model", self.model_name),
object="chat.completion",
choices=choices,
)
async def completion(self, **kwargs) -> Completion:
"""Send a completion request through the proxy.
Note: the proxy's chat/completions endpoint is the primary interface.
For raw completions, the proxy renders the prompt via chat template
internally. If you're calling this, you probably want chat_completion()
instead.
"""
# For completion() calls, we'd need a /completions endpoint on the proxy.
# Currently the proxy only exposes chat/completions. For now, raise
# a clear error.
raise NotImplementedError(
"ProxyManagedServer.completion() is not supported. "
"Use chat_completion() instead — the proxy handles template "
"rendering internally."
)
def get_state(self) -> Dict[str, Any]:
"""Get the current state synchronously from cache.
Call fetch_state() first to populate from the proxy, or use
the nodes returned by this method after a chat_completion() call.
Returns:
Dict with 'nodes': List[SequenceNode]
"""
if self._cached_nodes is not None:
return {"nodes": self._cached_nodes}
return {"nodes": []}
async def fetch_state(self) -> Dict[str, Any]:
"""Fetch current state from the proxy (async).
Returns:
Dict with 'nodes': List[SequenceNode]
"""
data = await self._get(f"/{self.session_uuid}/nodes")
nodes = []
for node_data in data.get("nodes", []):
nodes.append(SequenceNode(**node_data))
self._cached_nodes = nodes
return {"nodes": nodes}
def reset(self):
"""Clear cached state. The actual cleanup happens in __aexit__."""
self._cached_nodes = None
async def cleanup(self):
"""Delete the session on the proxy."""
try:
await self._delete(f"/{self.session_uuid}")
except Exception as e:
logger.warning(f"Failed to cleanup proxy session {self.session_uuid}: {e}")
# -- context manager support --
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Fetch final state before cleanup so callers can still access nodes
try:
await self.fetch_state()
except Exception:
pass
await self.cleanup()
async def create_proxy_session(
proxy_url: str,
base_url: Optional[str] = None,
tool_parser: str = "hermes",
track_tree: bool = False,
model_name: str = "unknown",
) -> ProxyManagedServer:
"""Create a new session on the proxy and return a ProxyManagedServer.
Args:
proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
base_url: Pin to a specific backend server. In production, this
comes from the atropos API's server allocation.
tool_parser: vLLM tool parser name (default: "hermes").
track_tree: Whether to use tree mode for tracking.
model_name: Model name for response objects.
Returns:
ProxyManagedServer instance ready to use.
"""
body = {
"tool_parser": tool_parser,
"track_tree": track_tree,
}
if base_url:
body["base_url"] = base_url
async with aiohttp.ClientSession() as session:
async with session.post(
f"{proxy_url.rstrip('/')}/sessions/create",
json=body,
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
data = await resp.json()
if resp.status != 200:
error_msg = data.get("error", {}).get("message", str(data))
raise RuntimeError(f"Failed to create proxy session: {error_msg}")
return ProxyManagedServer(
proxy_url=proxy_url,
session_uuid=data["uuid"],
model_name=data.get("model_name", model_name),
base_url=data.get("base_url"),
)
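A short sketch of the manual lifecycle when not going through ServerManager; the proxy URL is a placeholder and the printed node contents depend on SequenceNode's fields.

    import asyncio

    async def main():
        managed = await create_proxy_session("http://localhost:9100", tool_parser="hermes")
        async with managed:
            await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                n=2, max_tokens=64, temperature=1.0,
            )
            state = await managed.fetch_state()  # pull SequenceNodes back from the proxy
            for node in state["nodes"]:
                print(node)

    asyncio.run(main())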

atroposlib/envs/server_handling/server_manager.py (View file)

@ -41,6 +41,14 @@ class ServerManagerConfig(BaseModel):
"This is to help load balance servers."
),
)
proxy_url: Optional[str] = Field(
default=None,
description=(
"URL of the ManagedServer OpenAI proxy (e.g. 'http://localhost:9100'). "
"When set, managed_server(use_proxy=True) routes through this proxy. "
"Can also be set via ATROPOS_PROXY_URL environment variable."
),
)
class ServerManager:
@ -52,9 +60,18 @@ class ServerManager:
testing=False,
max_n_completions=8,
reasoning_config: Optional[ReasoningConfig] = None,
proxy_url: Optional[str] = None,
use_proxy: bool = False,
tool_parser: Optional[str] = None,
):
self.max_n_completions = max_n_completions
self.reasoning_config = reasoning_config
# Proxy config — when use_proxy=True, managed_server() routes
# through the proxy HTTP API instead of creating in-process instances
self.proxy_url = proxy_url or os.environ.get("ATROPOS_PROXY_URL")
self.use_proxy = use_proxy or bool(self.proxy_url)
# Tool parser — passed to ManagedServer for tool call support
self.tool_parser = tool_parser
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
@ -364,8 +381,10 @@ class ServerManager:
@asynccontextmanager
async def managed_server(
self, tokenizer=None
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
self,
tokenizer=None,
base_url: Optional[str] = None,
):
"""
Context manager that provides a ManagedServer instance.
@ -379,25 +398,63 @@ class ServerManager:
Args:
tokenizer: Optional tokenizer to use. If not provided, will attempt to
extract from server or create from model name.
base_url: Pin the session to a specific backend server by its base_url.
In production, this comes from the atropos API's server allocation.
Yields:
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
the selected server
ManagedServer, DummyManagedServer, or ProxyManagedServer instance
Raises:
NotImplementedError: If using OpenAI server without the
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
Example:
# In-process (default):
async with server_manager.managed_server() as managed:
response = await managed.chat_completion(
messages=[{"role": "user", "content": "Hello"}],
n=2
n=2, tools=[...], tool_choice="auto",
)
state = managed.get_state()
# Process state...
# State is automatically cleared when exiting context
# Via proxy (configured at init with proxy_url= or ATROPOS_PROXY_URL):
# server_manager = ServerManager(configs, proxy_url="http://proxy:9100")
async with server_manager.managed_server() as managed:
response = await managed.chat_completion(...)
api_url = managed.get_url() # for external apps
"""
# -- Proxy path --
if self.use_proxy:
resolved_proxy_url = self.proxy_url
if not resolved_proxy_url:
raise ValueError(
"use_proxy=True requires proxy_url or ATROPOS_PROXY_URL env var "
"to be set at ServerManager init"
)
from atroposlib.envs.server_handling.proxy_client import (
create_proxy_session,
)
model_name = (
self.servers[0].config.model_name
if self.servers and hasattr(self.servers[0], "config")
else "unknown"
)
proxy_managed = await create_proxy_session(
proxy_url=resolved_proxy_url,
base_url=base_url,
tool_parser=self.tool_parser or "hermes",
model_name=model_name,
)
try:
yield proxy_managed
finally:
await proxy_managed.cleanup()
return
# -- In-process path (existing logic) --
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
@ -441,7 +498,11 @@ class ServerManager:
finally:
managed.reset()
else:
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
managed = ManagedServer(
server=selected_server,
tokenizer=tokenizer,
tool_parser=self.tool_parser,
)
try:
yield managed
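A sketch of the two new knobs, assuming server_configs is an existing list of APIServerConfig objects; URLs and the parser name are illustrative.

    from atroposlib.envs.server_handling.server_manager import ServerManager

    # In-process path: the parser name is forwarded to each ManagedServer.
    manager = ServerManager(configs=server_configs, tool_parser="hermes")

    # Proxy path: pass proxy_url here or export ATROPOS_PROXY_URL.
    proxied = ServerManager(configs=server_configs, proxy_url="http://proxy:9100")

    async def one_turn():
        async with proxied.managed_server() as managed:
            await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}], n=1, max_tokens=32,
            )
            print(managed.get_url())  # hand this URL to an external OpenAI-client app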

atroposlib/envs/server_handling/tool_call_translator.py (View file)

@ -0,0 +1,251 @@
"""
Bidirectional translation between OpenAI tool_calls format and raw model text.
Uses vLLM's tool parsers directly — same parsing logic as vLLM's chat
completions endpoint. Supports 30+ model-specific parsers (hermes, llama,
mistral, deepseek, qwen3, etc.) via ToolParserManager.
Outbound (model → client): raw text with <tool_call> tags → structured OpenAI tool_calls
Inbound (client → model): OpenAI messages with tool roles → raw text for the chat template
"""
import json
import logging
import warnings
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# vLLM is optional — tool call parsing degrades gracefully without it
try:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
ChatCompletionRequest = None
ExtractedToolCallInformation = None
FunctionCall = None
ToolCall = None
ToolParser = None
ToolParserManager = None
class ToolCallTranslator:
"""Bidirectional translation between OpenAI tool_calls and raw model text.
Uses vLLM's tool parsers directly for outbound parsing (model output →
OpenAI format). Maintains a lookup table mapping tool_call IDs back to
the raw text that produced them, for reconstructing messages when the
caller sends tool results back.
The ManagedServer always stores raw text; this translator only
transforms what goes over the HTTP wire.
"""
def __init__(self, tokenizer: Any, parser_name: str = "hermes"):
"""
Args:
tokenizer: HuggingFace tokenizer instance.
parser_name: Name of the vLLM tool parser to use.
Available: hermes, llama3_json, llama4_json, mistral,
deepseek_v3, qwen3_coder, granite, internlm, etc.
See ToolParserManager.list_registered() for full list.
Warns:
    If vLLM is not installed, tool call parsing will be disabled but
    the translator can still handle message conversion and decoding.
"""
self.tokenizer = tokenizer
self.parser_name = parser_name
self.parser = None
if not VLLM_AVAILABLE:
warnings.warn(
"vLLM is not installed — tool call parsing is disabled. "
"Install vllm to enable structured tool call extraction from "
"model output (pip install vllm). The translator will still "
"handle message conversion and template rendering, but "
"parse_model_output() will return raw text without parsing.",
stacklevel=2,
)
else:
ParserClass = ToolParserManager.get_tool_parser(parser_name)
self.parser = ParserClass(tokenizer)
# tool_call_id → raw text segment that produced it.
# Used to reconstruct assistant messages when the caller sends
# follow-up messages with tool results.
self.call_id_to_raw_text: Dict[str, str] = {}
def parse_model_output(
self,
raw_text: str,
tool_choice: Optional[str] = None,
tools: Optional[List[dict]] = None,
) -> Tuple[Optional[str], Optional[List[ToolCall]], str]:
"""Parse raw model output into OpenAI response fields.
Args:
raw_text: Raw model output text (may contain <tool_call> tags etc.)
tool_choice: The tool_choice from the request. If "none", skip
parsing entirely. If "required", force tool_calls interpretation.
tools: Tool definitions from the request (needed for vLLM request obj).
Returns:
Tuple of (content, tool_calls, finish_reason):
content: Text content (before tool calls, or full text if no tools)
tool_calls: List of ToolCall objects, or None
finish_reason: "stop", "tool_calls", or "length"
"""
# If tool_choice is "none" or no tools defined, don't even try parsing
if tool_choice == "none" or not tools:
return raw_text, None, "stop"
# If vLLM isn't available, can't parse — return raw text
if self.parser is None:
return raw_text, None, "stop"
# Build a minimal ChatCompletionRequest for the parser
request = ChatCompletionRequest(
messages=[{"role": "user", "content": ""}],
model="proxy",
tools=tools,
tool_choice=tool_choice,
)
result: ExtractedToolCallInformation = self.parser.extract_tool_calls(
raw_text, request
)
if result.tools_called and result.tool_calls:
# Store mapping for reverse direction
for tc in result.tool_calls:
self.call_id_to_raw_text[tc.id] = raw_text
return result.content, result.tool_calls, "tool_calls"
else:
return raw_text, None, "stop"
def reconstruct_raw_text_from_tool_calls(self, tool_calls: List[dict]) -> str:
"""Reconstruct raw model text from OpenAI-format tool_calls.
When a caller sends an assistant message with tool_calls (e.g. in a
multi-turn conversation), we need to convert it back to the raw text
the model actually generated so the chat template produces the right
tokens.
First tries the lookup table (exact reconstruction). Falls back to
rebuilding from the structured data (best-effort).
Args:
tool_calls: List of tool call dicts from OpenAI format, each with
'id', 'type', 'function' (containing 'name' and 'arguments').
Returns:
Raw text with <tool_call> tags (or whatever format the parser uses).
"""
if not tool_calls:
return ""
# Try lookup table first — if the first call's ID is in the table,
# we can return the exact raw text
first_id = tool_calls[0].get("id", "")
if first_id in self.call_id_to_raw_text:
return self.call_id_to_raw_text[first_id]
# Fallback: reconstruct from structured data
# This is best-effort — the exact formatting may differ from what
# the model originally generated, but it's close enough for the
# chat template to handle.
# TODO: make the tag format configurable per parser (hermes uses
# <tool_call>, others may differ)
parts = []
for tc in tool_calls:
func = tc.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "{}")
# Parse arguments string back to dict for clean formatting
try:
args_dict = (
json.loads(arguments) if isinstance(arguments, str) else arguments
)
except (json.JSONDecodeError, TypeError):
args_dict = arguments
call_obj = {"name": name, "arguments": args_dict}
parts.append(f"<tool_call>{json.dumps(call_obj)}</tool_call>")
return "\n".join(parts)
def convert_messages_for_template(
self, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Convert OpenAI messages to raw format suitable for apply_chat_template.
Handles three cases:
1. Regular messages (user, system): pass through unchanged
2. Assistant messages with tool_calls: replace with raw text content
3. Tool result messages (role=tool): pass through (chat template handles them)
Args:
messages: OpenAI-format messages list.
Returns:
Messages list with tool_call assistant messages reconstructed to raw text.
"""
converted = []
for msg in messages:
role = msg.get("role", "")
if role == "assistant" and msg.get("tool_calls"):
# Reconstruct raw text from tool_calls
raw_text = self.reconstruct_raw_text_from_tool_calls(msg["tool_calls"])
# Prepend any content that came before the tool calls
content = msg.get("content") or ""
if content:
raw_text = content + "\n" + raw_text
converted.append(
{
"role": "assistant",
"content": raw_text,
}
)
else:
# Pass through as-is (user, system, tool, regular assistant)
converted.append(msg)
return converted
def decode_with_tool_awareness(
self,
token_ids: List[int],
has_tools: bool = False,
) -> str:
"""Decode token IDs, preserving tool call tags when tools are active.
Some tokenizers mark <tool_call> as a special token. If we decode with
skip_special_tokens=True (the default), the tags vanish before the
parser ever sees them. When tools are in play, we decode with
skip_special_tokens=False to preserve them.
Args:
token_ids: Token IDs to decode.
has_tools: Whether tools are active for this request.
Returns:
Decoded text string.
"""
return self.tokenizer.decode(
token_ids,
skip_special_tokens=not has_tools,
)
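Finally, a round-trip sketch of the translator in isolation, assuming vLLM is installed and a hermes-format model; the tool schema and raw output string are made up, and ToolCall is assumed to be a pydantic model (so model_dump() is available).

    from transformers import AutoTokenizer
    from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
    translator = ToolCallTranslator(tokenizer=tok, parser_name="hermes")

    tools = [{"type": "function", "function": {
        "name": "get_weather",  # hypothetical tool
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }}]
    raw = ('Let me check.\n'
           '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>')

    # Outbound: raw model text -> structured tool calls
    content, calls, finish = translator.parse_model_output(raw, tool_choice="auto", tools=tools)
    print(finish)  # "tool_calls" when the parser matches, otherwise "stop"

    # Inbound: an assistant message carrying those tool_calls folds back to raw text
    if calls:
        msgs = [{"role": "assistant", "tool_calls": [c.model_dump() for c in calls]}]
        print(translator.convert_messages_for_template(msgs)[0]["content"])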