mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
|
|
@ -62,6 +62,7 @@ class ManagedServer:
|
|||
server: APIServer,
|
||||
tokenizer: Optional[Any] = None,
|
||||
track_tree: bool = False,
|
||||
tool_parser: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the managed server.
|
||||
|
|
@ -73,10 +74,17 @@ class ManagedServer:
|
|||
track_tree: If True, maintains a tree structure with parent-child links
|
||||
(for multi-turn RL with per-step advantages). If False (default),
|
||||
maintains a simple list of current nodes that updates in-place.
|
||||
tool_parser: Optional vLLM tool parser name (e.g. "hermes", "llama3_json",
|
||||
"mistral", etc.). If provided, enables tool call support in
|
||||
chat_completion(). The parser handles extraction of structured
|
||||
tool calls from raw model output. See
|
||||
ToolParserManager.list_registered() for available parsers.
|
||||
"""
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self._tool_parser_name = tool_parser
|
||||
self._translator = None # Lazy init
|
||||
|
||||
# Initialize storage based on mode
|
||||
if track_tree:
|
||||
|
|
@ -107,19 +115,57 @@ class ManagedServer:
|
|||
)
|
||||
self.tokenizer = None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
||||
def _get_translator(self):
    """Lazily create the ToolCallTranslator when first needed.

    Returns None if tool_parser was not specified or if vLLM is not
    installed (the translator will warn on creation in that case).

    Returns:
        The cached ToolCallTranslator instance, or None when tool parsing
        is unavailable (no parser name, no tokenizer, or creation failed).
    """
    # Only attempt creation once: requires a parser name AND a tokenizer.
    if self._translator is None and self._tool_parser_name and self.tokenizer:
        try:
            # Imported lazily so environments without vLLM installed can
            # still use ManagedServer without tool support.
            from atroposlib.envs.server_handling.tool_call_translator import (
                ToolCallTranslator,
            )

            self._translator = ToolCallTranslator(
                tokenizer=self.tokenizer,
                parser_name=self._tool_parser_name,
            )
        except Exception as e:
            warnings.warn(
                f"Failed to create ToolCallTranslator: {e}. "
                "Tool call parsing will be disabled.",
                stacklevel=2,
            )
            self._tool_parser_name = None  # Don't retry
            return None
    return self._translator
|
||||
|
||||
def _convert_messages_to_prompt(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Convert chat messages to prompt text using tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
tools: Optional list of tool definitions (OpenAI format). Passed to
|
||||
apply_chat_template() so the template can inject tool defs
|
||||
into the system prompt.
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# If tools are active and we have a translator, convert any assistant
|
||||
# messages with tool_calls back to raw text first
|
||||
if tools and self._get_translator():
|
||||
messages = self._get_translator().convert_messages_for_template(messages)
|
||||
|
||||
if self.tokenizer is None:
|
||||
# Fallback: simple concatenation
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
|
||||
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
# Only add generation prompt if last message is not from assistant
|
||||
|
|
@ -127,13 +173,19 @@ class ManagedServer:
|
|||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Build kwargs
|
||||
template_kwargs = {
|
||||
"tokenize": False,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
if tools:
|
||||
template_kwargs["tools"] = tools
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
|
||||
|
||||
def _debug_requests_enabled(self) -> bool:
|
||||
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
|
||||
|
|
@ -268,15 +320,35 @@ class ManagedServer:
|
|||
Internally converts to prompt, calls tokens_and_logprobs_completion,
|
||||
tracks the sequence, and reconstructs a ChatCompletion response.
|
||||
|
||||
Supports tool calling when a tool_parser was provided at init:
|
||||
- Accepts `tools` and `tool_choice` kwargs
|
||||
- Converts inbound assistant tool_call messages to raw text
|
||||
- Parses outbound model output for tool calls
|
||||
- Returns ChatCompletion with proper tool_calls in choices
|
||||
- Preserves raw text in tracked nodes (tool parsing is response-only)
|
||||
|
||||
Args:
|
||||
**kwargs: Standard chat completion kwargs (messages, n, etc.)
|
||||
**kwargs: Standard chat completion kwargs (messages, n, max_tokens,
|
||||
temperature, tools, tool_choice, etc.)
|
||||
|
||||
Returns:
|
||||
ChatCompletion response
|
||||
ChatCompletion response (with tool_calls if detected)
|
||||
"""
|
||||
# Get input text
|
||||
# Extract tool-related kwargs
|
||||
tools = kwargs.pop("tools", None)
|
||||
tool_choice = kwargs.pop("tool_choice", None)
|
||||
has_tools = bool(tools) and self._get_translator() is not None
|
||||
|
||||
# Default tool_choice to "auto" if tools provided
|
||||
if has_tools and tool_choice is None:
|
||||
tool_choice = "auto"
|
||||
|
||||
# Get input text — passes tools for template rendering and
|
||||
# handles reconstruction of inbound tool_call messages
|
||||
messages = kwargs.get("messages", [])
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
prompt = self._convert_messages_to_prompt(
|
||||
messages, tools=tools if has_tools else None
|
||||
)
|
||||
|
||||
# Handle parent node and extending logic based on mode
|
||||
if self.track_tree:
|
||||
|
|
@ -296,11 +368,12 @@ class ManagedServer:
|
|||
msg_count = len(messages)
|
||||
prompt_preview = prompt.replace("\n", "\\n")[:600]
|
||||
logger.debug(
|
||||
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
|
||||
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s tools=%s",
|
||||
msg_count,
|
||||
completion_kwargs.get("n"),
|
||||
completion_kwargs.get("max_tokens"),
|
||||
completion_kwargs.get("temperature"),
|
||||
bool(tools),
|
||||
)
|
||||
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
|
||||
|
||||
|
|
@ -336,15 +409,18 @@ class ManagedServer:
|
|||
else:
|
||||
finish_reason = finish_reason_raw
|
||||
|
||||
# Decode completion text
|
||||
# Decode completion text — use skip_special_tokens=False when
|
||||
# tools are active so <tool_call> tags aren't stripped
|
||||
if self.tokenizer is not None:
|
||||
completion_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
output_tokens,
|
||||
skip_special_tokens=not has_tools,
|
||||
)
|
||||
else:
|
||||
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
|
||||
|
||||
# Create and store sequence node
|
||||
# Create and store sequence node — always uses the raw text,
|
||||
# tool parsing only affects the ChatCompletion response
|
||||
node = self._create_sequence_node(
|
||||
input_text=prompt,
|
||||
parent_node=parent_node,
|
||||
|
|
@ -373,14 +449,50 @@ class ManagedServer:
|
|||
# New context - append to list
|
||||
self.current_nodes.append(node)
|
||||
|
||||
# Parse tool calls from raw output if tools are active
|
||||
tool_calls_parsed = None
|
||||
content_for_response = completion_text
|
||||
if has_tools and tool_choice != "none":
|
||||
translator = self._get_translator()
|
||||
content_for_response, tool_calls_parsed, finish_reason = (
|
||||
translator.parse_model_output(
|
||||
raw_text=completion_text,
|
||||
tool_choice=(
|
||||
tool_choice if isinstance(tool_choice, str) else "auto"
|
||||
),
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
# Build choice
|
||||
message_kwargs = {
|
||||
"content": content_for_response,
|
||||
"role": "assistant",
|
||||
}
|
||||
# Note: openai's ChatCompletionMessage model handles tool_calls
|
||||
# but we can't pass them through the constructor easily. We'll
|
||||
# attach them after construction if needed.
|
||||
choice = Choice(
|
||||
finish_reason=finish_reason,
|
||||
index=i,
|
||||
message=ChatCompletionMessage(
|
||||
content=completion_text, role="assistant"
|
||||
),
|
||||
message=ChatCompletionMessage(**message_kwargs),
|
||||
)
|
||||
|
||||
# Attach tool_calls to the message if present
|
||||
if tool_calls_parsed:
|
||||
choice.message.tool_calls = [
|
||||
# Convert vLLM ToolCall to openai ToolCall format
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_parsed
|
||||
]
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
# Construct ChatCompletion response
|
||||
|
|
|
|||
623
atroposlib/envs/server_handling/managed_server_proxy.py
Normal file
623
atroposlib/envs/server_handling/managed_server_proxy.py
Normal file
|
|
@ -0,0 +1,623 @@
|
|||
"""
|
||||
OpenAI-compatible chat completions proxy over ManagedServer.
|
||||
|
||||
Exposes /{uuid}/v1/chat/completions and related endpoints so that external
|
||||
environment microservices can interact with ManagedServer via standard
|
||||
OpenAI API during multi-step rollouts.
|
||||
|
||||
Each UUID maps to a session containing a ManagedServer instance. Tool call
|
||||
parsing uses vLLM's parsers directly. The ManagedServer always stores raw
|
||||
text — tool call translation only affects the HTTP wire format.
|
||||
|
||||
Uses ServerManager for routing across multiple backend servers (load balancing,
|
||||
health checks, etc.) — same infrastructure as the rest of atropos.
|
||||
|
||||
Usage:
|
||||
# Standalone with JSON config
|
||||
python -m atroposlib.envs.server_handling.managed_server_proxy \\
|
||||
--config servers.json \\
|
||||
--port 9100
|
||||
|
||||
# Or mount into existing FastAPI app
|
||||
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
||||
app = create_app(server_manager, tokenizer, model_name)
|
||||
|
||||
servers.json example:
|
||||
{
|
||||
"model_name": "Qwen/Qwen3-4B",
|
||||
"servers": [
|
||||
{"base_url": "http://gpu1:8000/v1", "server_type": "vllm", "api_key": ""},
|
||||
{"base_url": "http://gpu2:8000/v1", "server_type": "vllm", "api_key": ""}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid as uuid_lib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import (
|
||||
DummyManagedServer,
|
||||
ManagedServer,
|
||||
)
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """One OpenAI-format chat message on the request side."""

    role: str  # e.g. "system"/"user"/"assistant"/"tool" — not validated here
    content: Optional[str] = None  # may be None, e.g. assistant tool_call messages
    tool_calls: Optional[List[dict]] = None  # OpenAI tool_call dicts on assistant messages
    tool_call_id: Optional[str] = None  # set on tool-result messages
    name: Optional[str] = None  # optional participant/tool name
|
||||
|
||||
|
||||
class ChatCompletionProxyRequest(BaseModel):
    """Body of POST /{uuid}/v1/chat/completions (OpenAI-compatible subset)."""

    messages: List[ChatMessage]
    max_tokens: int = 1024
    temperature: float = 1.0
    n: int = 1  # number of completions to sample
    stop: Optional[List[str]] = None
    tools: Optional[List[dict]] = None  # OpenAI tool definitions
    tool_choice: Optional[Any] = None  # "auto", "none", "required", or dict
    # TODO: top_p, frequency_penalty, presence_penalty, seed, response_format,
    # logprobs, top_logprobs — pass through to backend when implemented
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
    """Body of POST /sessions/create."""

    tool_parser: str = "hermes"  # vLLM tool parser name used by this session
    track_tree: bool = False  # forwarded to ManagedServer(track_tree=...)
    # Pin to a specific backend server by its base_url. In production,
    # the caller gets their assigned server from the atropos API and
    # passes it here. If omitted, falls back to picking the server
    # with the most open semaphore slots (fine for dev/testing).
    base_url: Optional[str] = None
|
||||
|
||||
|
||||
class SessionCreateResponse(BaseModel):
    """Response of POST /sessions/create."""

    uuid: str  # session UUID used in all per-session endpoint paths
    model_name: str
    tool_parser: str
    base_url: Optional[str] = None  # Which backend was selected
    created_at: float  # unix timestamp
|
||||
|
||||
|
||||
class RenderResponse(BaseModel):
    """Response of the /render endpoint: the templated prompt and its tokens."""

    prompt_text: str  # chat-template-rendered prompt
    token_ids: List[int]  # tokenizer.encode(prompt_text, add_special_tokens=False)
    num_tokens: int  # len(token_ids)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class SessionState:
    """Per-UUID proxy session: a ManagedServer plus bookkeeping metadata."""

    uuid: str
    managed_server: ManagedServer
    translator: ToolCallTranslator  # used by the /render endpoint for prompt previews
    model_name: str
    base_url: Optional[str] = None  # Which backend server this session is pinned to
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)  # bumped on each lookup
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI error format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def openai_error(
    status_code: int, message: str, error_type: str = "invalid_request_error"
) -> JSONResponse:
    """Wrap an error in the OpenAI wire format: {"error": {...}}.

    Args:
        status_code: HTTP status code (also echoed in the error body).
        message: Human-readable error description.
        error_type: OpenAI error type string.
    """
    body = {
        "error": {
            "message": message,
            "type": error_type,
            "code": status_code,
        }
    }
    return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_app(
|
||||
server_manager: ServerManager,
|
||||
tokenizer: Any,
|
||||
model_name: str = "unknown",
|
||||
) -> FastAPI:
|
||||
"""Create the proxy FastAPI app.
|
||||
|
||||
Args:
|
||||
server_manager: ServerManager instance managing one or more backend
|
||||
servers (VLLMServer, SGLangServer, etc.). Used to pick the most
|
||||
available server when creating sessions.
|
||||
tokenizer: HuggingFace tokenizer for the model.
|
||||
model_name: Model name to report in responses.
|
||||
|
||||
Returns:
|
||||
FastAPI app with all endpoints registered.
|
||||
"""
|
||||
|
||||
app = FastAPI(title="ManagedServer OpenAI Proxy")
|
||||
sessions: Dict[str, SessionState] = {}
|
||||
|
||||
# -- helpers --
|
||||
|
||||
def _get_session(session_uuid: str) -> SessionState:
    """Look up a session by UUID and bump its last-accessed timestamp.

    Raises:
        HTTPException: 404 when the UUID is unknown.
    """
    session = sessions.get(session_uuid)
    if session is not None:
        session.last_accessed = time.time()
        return session
    raise HTTPException(
        status_code=404, detail=f"Session {session_uuid} not found"
    )
||||
|
||||
def _get_server_base_url(server) -> Optional[str]:
|
||||
"""Get the base_url from a server's config, if available."""
|
||||
if hasattr(server, "config") and hasattr(server.config, "base_url"):
|
||||
return server.config.base_url
|
||||
return None
|
||||
|
||||
def _select_server(base_url: Optional[str] = None):
    """Pick a server from the manager.

    Args:
        base_url: If provided, pin to the server with this base_url
            (trailing slashes ignored). Raises 404 if no server matches.
            If None, picks the healthy server with the most open
            semaphore slots (mirrors ServerManager.managed_server() logic).

    Returns:
        Selected APIServer instance.

    Raises:
        HTTPException: 404 when a pinned base_url matches no server,
            503 when auto-selecting and no healthy server exists.
    """
    if base_url is not None:
        # Pin to specific server by base_url
        wanted = base_url.rstrip("/")
        for server in server_manager.servers:
            server_url = _get_server_base_url(server)
            if server_url and server_url.rstrip("/") == wanted:
                return server
        # No match — list available URLs in error
        available = [
            _get_server_base_url(s)
            for s in server_manager.servers
            if _get_server_base_url(s)
        ]
        raise HTTPException(
            status_code=404,
            detail=f"No server with base_url '{base_url}'. Available: {available}",
        )

    # Auto-select the healthy server with the most open semaphore slots.
    best = None
    best_slots = -1
    for server in server_manager.servers:
        if not server.server_healthy:
            continue
        # NOTE: sem._value is a private asyncio.Semaphore attribute —
        # presumably matches usage elsewhere in atropos, but fragile.
        slots = server.sem._value
        if slots > best_slots:
            best = server
            best_slots = slots

    if best is None:
        # Fix: previously this fell through to servers[0] even when that
        # server was unhealthy, or raised IndexError on an empty list.
        raise HTTPException(
            status_code=503,
            detail="No healthy backend servers available",
        )
    return best
|
||||
|
||||
def _render_prompt(
    messages: List[Dict[str, Any]],
    tools: Optional[List[dict]] = None,
    translator: Optional[ToolCallTranslator] = None,
) -> str:
    """Render messages to prompt text via the tokenizer's chat template.

    If a translator is provided, OpenAI-style assistant tool_call messages
    are converted back to raw text before templating.
    """
    # Normalize pydantic models to plain dicts; pass dicts through as-is.
    msg_dicts = [
        msg.model_dump(exclude_none=True) if isinstance(msg, BaseModel) else msg
        for msg in messages
    ]

    # Reconstruct raw text for assistant tool_call messages
    if translator:
        msg_dicts = translator.convert_messages_for_template(msg_dicts)

    # Assemble kwargs for apply_chat_template
    render_kwargs: Dict[str, Any] = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    if tools:
        render_kwargs["tools"] = tools

    return tokenizer.apply_chat_template(msg_dicts, **render_kwargs)
|
||||
|
||||
def _build_openai_response(
|
||||
choices_data: List[dict],
|
||||
model: str,
|
||||
) -> dict:
|
||||
"""Build an OpenAI ChatCompletion response dict."""
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid_lib.uuid4().hex[:12]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": choices_data,
|
||||
}
|
||||
|
||||
# -- endpoints --
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
healthy_servers = sum(1 for s in server_manager.servers if s.server_healthy)
|
||||
return {
|
||||
"status": "ok",
|
||||
"model": model_name,
|
||||
"sessions": len(sessions),
|
||||
"servers": len(server_manager.servers),
|
||||
"healthy_servers": healthy_servers,
|
||||
}
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "atropos",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@app.post("/setup")
|
||||
async def setup(request: Request):
|
||||
"""Receive server configuration from ServerManager.
|
||||
|
||||
Accepts the same JSON format as the standalone --config file.
|
||||
Replaces the current server_manager's servers with the new config.
|
||||
Called by ServerManager at startup to push its config to the proxy.
|
||||
|
||||
Body: {"model_name": "...", "servers": [{"base_url": "...", "server_type": "vllm"}, ...]}
|
||||
"""
|
||||
config = await request.json()
|
||||
|
||||
new_configs = []
|
||||
new_model = config.get("model_name", model_name)
|
||||
for srv in config.get("servers", []):
|
||||
new_configs.append(
|
||||
APIServerConfig(
|
||||
model_name=new_model,
|
||||
base_url=srv["base_url"],
|
||||
api_key=srv.get("api_key", ""),
|
||||
server_type=srv.get("server_type", "vllm"),
|
||||
num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
|
||||
num_requests_for_eval=srv.get("num_requests_for_eval", 64),
|
||||
timeout=srv.get("timeout", 1200),
|
||||
tokenizer_name=config.get("tokenizer_name", "none"),
|
||||
)
|
||||
)
|
||||
|
||||
if new_configs:
|
||||
new_manager = ServerManager(configs=new_configs)
|
||||
server_manager.servers = new_manager.servers
|
||||
logger.info("Setup: replaced servers with %d new configs", len(new_configs))
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"servers": len(server_manager.servers),
|
||||
"model_name": new_model,
|
||||
}
|
||||
|
||||
@app.get("/servers")
|
||||
async def list_servers():
|
||||
"""List available backend servers.
|
||||
|
||||
Useful for discovery/debugging. In production, server allocation
|
||||
is managed by the atropos API — the environment gets told which
|
||||
server to use and passes that base_url to POST /sessions/create.
|
||||
"""
|
||||
server_list = []
|
||||
for i, server in enumerate(server_manager.servers):
|
||||
url = _get_server_base_url(server)
|
||||
healthy = getattr(server, "server_healthy", True)
|
||||
server_list.append(
|
||||
{
|
||||
"index": i,
|
||||
"base_url": url,
|
||||
"healthy": healthy,
|
||||
"model_name": (
|
||||
getattr(server.config, "model_name", model_name)
|
||||
if hasattr(server, "config")
|
||||
else model_name
|
||||
),
|
||||
"server_type": (
|
||||
getattr(server.config, "server_type", "unknown")
|
||||
if hasattr(server, "config")
|
||||
else "unknown"
|
||||
),
|
||||
}
|
||||
)
|
||||
return {"servers": server_list}
|
||||
|
||||
@app.post("/sessions/create", response_model=SessionCreateResponse)
|
||||
async def create_session(request: SessionCreateRequest):
|
||||
session_uuid = str(uuid_lib.uuid4())
|
||||
|
||||
# Pick server — pinned to base_url if specified, otherwise most available
|
||||
selected_server = _select_server(base_url=request.base_url)
|
||||
selected_url = _get_server_base_url(selected_server)
|
||||
|
||||
# Use DummyManagedServer for OpenAI endpoints (no logprobs support)
|
||||
if isinstance(selected_server, OpenAIServer):
|
||||
logger.info(
|
||||
"Session %s using DummyManagedServer (OpenAI endpoint). "
|
||||
"Token IDs and logprobs will be placeholders.",
|
||||
session_uuid,
|
||||
)
|
||||
managed = DummyManagedServer(
|
||||
server=selected_server,
|
||||
tokenizer=tokenizer,
|
||||
track_tree=request.track_tree,
|
||||
)
|
||||
else:
|
||||
managed = ManagedServer(
|
||||
server=selected_server,
|
||||
tokenizer=tokenizer,
|
||||
track_tree=request.track_tree,
|
||||
tool_parser=request.tool_parser,
|
||||
)
|
||||
|
||||
# Translator kept for the render endpoint (prompt preview)
|
||||
translator = ToolCallTranslator(
|
||||
tokenizer=tokenizer,
|
||||
parser_name=request.tool_parser,
|
||||
)
|
||||
session = SessionState(
|
||||
uuid=session_uuid,
|
||||
managed_server=managed,
|
||||
translator=translator,
|
||||
model_name=model_name,
|
||||
base_url=selected_url,
|
||||
)
|
||||
sessions[session_uuid] = session
|
||||
|
||||
return SessionCreateResponse(
|
||||
uuid=session_uuid,
|
||||
model_name=model_name,
|
||||
tool_parser=request.tool_parser,
|
||||
base_url=selected_url,
|
||||
created_at=session.created_at,
|
||||
)
|
||||
|
||||
@app.get("/sessions")
|
||||
async def list_sessions():
|
||||
return {
|
||||
"sessions": [
|
||||
{
|
||||
"uuid": s.uuid,
|
||||
"model_name": s.model_name,
|
||||
"base_url": s.base_url,
|
||||
"created_at": s.created_at,
|
||||
"last_accessed": s.last_accessed,
|
||||
"num_nodes": len(
|
||||
s.managed_server.current_nodes
|
||||
if hasattr(s.managed_server, "current_nodes")
|
||||
else s.managed_server.sequences
|
||||
),
|
||||
}
|
||||
for s in sessions.values()
|
||||
]
|
||||
}
|
||||
|
||||
@app.post("/{session_uuid}/v1/chat/completions")
|
||||
async def chat_completions(session_uuid: str, request: ChatCompletionProxyRequest):
|
||||
session = _get_session(session_uuid)
|
||||
managed = session.managed_server
|
||||
|
||||
if not request.messages:
|
||||
return openai_error(400, "messages must not be empty")
|
||||
|
||||
# Convert pydantic messages to dicts
|
||||
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
|
||||
|
||||
# Build kwargs — ManagedServer.chat_completion() handles all tool
|
||||
# call logic internally (template rendering, inbound reconstruction,
|
||||
# outbound parsing, skip_special_tokens)
|
||||
completion_kwargs = {
|
||||
"messages": messages,
|
||||
"n": request.n,
|
||||
"max_tokens": request.max_tokens,
|
||||
"temperature": request.temperature,
|
||||
}
|
||||
if request.stop:
|
||||
completion_kwargs["stop"] = request.stop
|
||||
if request.tools:
|
||||
completion_kwargs["tools"] = request.tools
|
||||
if request.tool_choice is not None:
|
||||
completion_kwargs["tool_choice"] = request.tool_choice
|
||||
|
||||
try:
|
||||
result = await managed.chat_completion(**completion_kwargs)
|
||||
except Exception as e:
|
||||
logger.exception("Completion failed")
|
||||
return openai_error(
|
||||
500, f"Completion failed: {e}", error_type="server_error"
|
||||
)
|
||||
|
||||
# Convert ChatCompletion to JSON-serializable response
|
||||
choices = []
|
||||
for choice in result.choices:
|
||||
choice_data = {
|
||||
"index": choice.index,
|
||||
"message": {
|
||||
"role": choice.message.role,
|
||||
"content": choice.message.content,
|
||||
},
|
||||
"finish_reason": choice.finish_reason,
|
||||
}
|
||||
if choice.message.tool_calls:
|
||||
choice_data["message"]["tool_calls"] = choice.message.tool_calls
|
||||
|
||||
choices.append(choice_data)
|
||||
|
||||
return _build_openai_response(choices, model_name)
|
||||
|
||||
@app.post("/{session_uuid}/v1/chat/completions/render")
|
||||
async def render_prompt(session_uuid: str, request: ChatCompletionProxyRequest):
|
||||
session = _get_session(session_uuid)
|
||||
|
||||
try:
|
||||
prompt_text = _render_prompt(
|
||||
messages=request.messages,
|
||||
tools=request.tools,
|
||||
translator=session.translator,
|
||||
)
|
||||
except Exception as e:
|
||||
return openai_error(400, f"Failed to render prompt: {e}")
|
||||
|
||||
token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
|
||||
return RenderResponse(
|
||||
prompt_text=prompt_text,
|
||||
token_ids=token_ids,
|
||||
num_tokens=len(token_ids),
|
||||
)
|
||||
|
||||
@app.get("/{session_uuid}/nodes")
|
||||
async def get_nodes(session_uuid: str):
|
||||
session = _get_session(session_uuid)
|
||||
state = session.managed_server.get_state()
|
||||
|
||||
if session.managed_server.track_tree:
|
||||
nodes = list(state.get("sequences", {}).values())
|
||||
else:
|
||||
nodes = state.get("nodes", [])
|
||||
|
||||
return {
|
||||
"nodes": [node.model_dump() for node in nodes],
|
||||
}
|
||||
|
||||
@app.delete("/{session_uuid}")
|
||||
async def delete_session(session_uuid: str):
|
||||
session = sessions.pop(session_uuid, None)
|
||||
if session is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Session {session_uuid} not found"
|
||||
)
|
||||
session.managed_server.reset()
|
||||
return {"status": "deleted", "uuid": session_uuid}
|
||||
|
||||
# -- exception handler for OpenAI-style errors --
|
||||
|
||||
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Convert FastAPI HTTPExceptions into OpenAI-style error bodies."""
    return openai_error(exc.status_code, exc.detail)
|
||||
|
||||
# NOTE on concurrent access: each UUID is meant to represent a single
|
||||
# rollout session used by one caller at a time. If you send concurrent
|
||||
# requests to the same UUID, the ManagedServer's node extension logic
|
||||
# may get confused because it does prefix matching on current_nodes.
|
||||
# Don't do that. UUIDs are cheap, make a new one.
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standalone entrypoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
    """Standalone entrypoint: load a JSON server config and serve the proxy.

    Reads --config (required JSON file with "model_name" and "servers"),
    builds a ServerManager over the listed backends, and runs the proxy
    app under uvicorn on --host/--port.
    """
    import argparse
    import json

    import uvicorn
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="ManagedServer OpenAI Proxy")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to JSON config file with server definitions",
    )
    parser.add_argument("--port", type=int, default=9100, help="Proxy port")
    parser.add_argument("--host", default="0.0.0.0", help="Proxy host")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)

    model_name = config["model_name"]
    # Tokenizer may differ from the model name (e.g. a shared tokenizer repo).
    tokenizer = AutoTokenizer.from_pretrained(config.get("tokenizer_name", model_name))

    # Build APIServerConfigs from the JSON
    server_configs = []
    for srv in config["servers"]:
        server_configs.append(
            APIServerConfig(
                model_name=model_name,
                base_url=srv["base_url"],
                api_key=srv.get("api_key", ""),
                server_type=srv.get("server_type", "vllm"),
                num_max_requests_at_once=srv.get("num_max_requests_at_once", 512),
                num_requests_for_eval=srv.get("num_requests_for_eval", 64),
                timeout=srv.get("timeout", 1200),
                tokenizer_name=config.get("tokenizer_name", "none"),
            )
        )

    server_manager = ServerManager(configs=server_configs)

    app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name=model_name,
    )

    # Startup banner: where we listen and which endpoints exist.
    print(f"Starting ManagedServer OpenAI Proxy on {args.host}:{args.port}")
    print(f"  Model: {model_name}")
    print(f"  Backends: {len(server_configs)} server(s)")
    for i, cfg in enumerate(server_configs):
        print(f"    [{i}] {cfg.base_url} ({cfg.server_type})")
    print()
    print("Endpoints:")
    print("  POST   /sessions/create")
    print("  POST   /{uuid}/v1/chat/completions")
    print("  POST   /{uuid}/v1/chat/completions/render")
    print("  GET    /{uuid}/nodes")
    print("  DELETE /{uuid}")
    print("  GET    /sessions")
    print("  GET    /v1/models")
    print("  GET    /health")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
292
atroposlib/envs/server_handling/proxy_client.py
Normal file
292
atroposlib/envs/server_handling/proxy_client.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""
|
||||
Client that talks to the ManagedServer OpenAI proxy over HTTP.
|
||||
|
||||
Implements the same interface as ManagedServer so it can be used as a
|
||||
drop-in replacement via ServerManager.managed_server(use_proxy=True).
|
||||
|
||||
The proxy handles all the token tracking, tool call parsing, and sequence
|
||||
management. This client just ferries requests/responses over HTTP and
|
||||
reconstructs the SequenceNode objects from the JSON.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid as uuid_lib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion import Completion # noqa: F401 — used in type hint
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import SequenceNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyManagedServer:
    """Client that talks to the ManagedServer OpenAI proxy.

    Same interface as ManagedServer — chat_completion(), completion(),
    get_state(), reset(). But instead of doing token tracking in-process,
    delegates everything to the proxy over HTTP.

    Created by ServerManager.managed_server(use_proxy=True).

    Example:
        async with server_manager.managed_server(use_proxy=True) as managed:
            # Same API as regular ManagedServer
            resp = await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                n=4, max_tokens=100, temperature=1.0,
            )
            state = managed.get_state()
            nodes = state["nodes"]

            # Extra: get URL for external apps to use directly
            url = managed.get_url()
            # → "http://proxy:9100/{uuid}/v1"
    """

    def __init__(
        self,
        proxy_url: str,
        session_uuid: str,
        model_name: str = "unknown",
        base_url: Optional[str] = None,
    ):
        """
        Args:
            proxy_url: Base URL of the proxy (e.g. "http://localhost:9100")
            session_uuid: UUID of the session on the proxy.
            model_name: Model name (for response objects).
            base_url: The backend server this session is pinned to.
        """
        self.proxy_url = proxy_url.rstrip("/")
        self.session_uuid = session_uuid
        self.model_name = model_name
        self.base_url = base_url

        # Cache for nodes (populated by fetch_state; read by get_state)
        self._cached_nodes: Optional[List[SequenceNode]] = None

    def get_url(self) -> str:
        """Get the OpenAI-compatible API URL for this session.

        External apps can use this URL with any OpenAI client:
            client = openai.OpenAI(base_url=managed.get_url())
            client.chat.completions.create(messages=..., tools=...)

        Returns:
            URL like "http://proxy:9100/{uuid}/v1"
        """
        return f"{self.proxy_url}/{self.session_uuid}/v1"

    async def _post(self, path: str, json: dict, timeout: int = 300) -> dict:
        """Make a POST request to the proxy.

        Raises:
            RuntimeError: if the proxy responds with a non-200 status.
        """
        url = f"{self.proxy_url}{path}"
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url, json=json, timeout=aiohttp.ClientTimeout(total=timeout)
            ) as resp:
                data = await resp.json()
                if resp.status != 200:
                    error_msg = data.get("error", {}).get("message", str(data))
                    raise RuntimeError(
                        f"Proxy request failed ({resp.status}): {error_msg}"
                    )
                return data

    async def _get(self, path: str, timeout: int = 30) -> dict:
        """Make a GET request to the proxy.

        Raises:
            RuntimeError: if the proxy responds with a non-200 status.
        """
        url = f"{self.proxy_url}{path}"
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=timeout)
            ) as resp:
                data = await resp.json()
                if resp.status != 200:
                    error_msg = data.get("error", {}).get("message", str(data))
                    raise RuntimeError(
                        f"Proxy request failed ({resp.status}): {error_msg}"
                    )
                return data

    async def _delete(self, path: str, timeout: int = 30) -> dict:
        """Make a DELETE request to the proxy.

        Unlike _post/_get, a non-200 status is not raised here — callers
        (cleanup) treat deletion as best-effort.
        """
        url = f"{self.proxy_url}{path}"
        async with aiohttp.ClientSession() as session:
            async with session.delete(
                url, timeout=aiohttp.ClientTimeout(total=timeout)
            ) as resp:
                data = await resp.json()
                return data

    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """Send a chat completion request through the proxy.

        Same interface as ManagedServer.chat_completion().
        The proxy handles template rendering, tool call parsing,
        and token/logprob tracking.
        """
        # Convert messages to serializable format
        messages = kwargs.get("messages", [])
        serialized_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                serialized_messages.append(msg)
            else:
                serialized_messages.append(dict(msg))

        body = {
            "messages": serialized_messages,
            "max_tokens": kwargs.get("max_tokens", 1024),
            "temperature": kwargs.get("temperature", 1.0),
            "n": kwargs.get("n", 1),
        }
        if kwargs.get("stop"):
            body["stop"] = kwargs["stop"]
        if kwargs.get("tools"):
            body["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice") is not None:
            body["tool_choice"] = kwargs["tool_choice"]

        data = await self._post(f"/{self.session_uuid}/v1/chat/completions", json=body)

        # Reconstruct ChatCompletion from proxy response
        choices = []
        for choice_data in data.get("choices", []):
            msg = choice_data.get("message", {})
            choice = Choice(
                finish_reason=choice_data.get("finish_reason", "stop"),
                index=choice_data.get("index", 0),
                message=ChatCompletionMessage(
                    content=msg.get("content"),
                    role=msg.get("role", "assistant"),
                    # Fix: the proxy parses structured tool calls out of the
                    # raw model text (the point of tool_parser support) —
                    # previously these were silently dropped here, breaking
                    # multi-turn tool use. Pydantic coerces the raw dicts
                    # into ChatCompletionMessageToolCall objects.
                    tool_calls=msg.get("tool_calls"),
                ),
            )
            choices.append(choice)

        return ChatCompletion(
            id=data.get("id", str(uuid_lib.uuid4())),
            created=data.get("created", int(time.time())),
            model=data.get("model", self.model_name),
            object="chat.completion",
            choices=choices,
        )

    async def completion(self, **kwargs) -> Completion:
        """Send a completion request through the proxy.

        Note: the proxy's chat/completions endpoint is the primary interface.
        For raw completions, the proxy renders the prompt via chat template
        internally. If you're calling this, you probably want chat_completion()
        instead.

        Raises:
            NotImplementedError: always — the proxy exposes no raw
                /completions endpoint yet.
        """
        # For completion() calls, we'd need a /completions endpoint on the proxy.
        # Currently the proxy only exposes chat/completions. For now, raise
        # a clear error.
        raise NotImplementedError(
            "ProxyManagedServer.completion() is not supported. "
            "Use chat_completion() instead — the proxy handles template "
            "rendering internally."
        )

    def get_state(self) -> Dict[str, Any]:
        """Get the current state synchronously from cache.

        Call fetch_state() first to populate from the proxy, or use
        the nodes returned by this method after a chat_completion() call.

        Returns:
            Dict with 'nodes': List[SequenceNode]
        """
        if self._cached_nodes is not None:
            return {"nodes": self._cached_nodes}
        return {"nodes": []}

    async def fetch_state(self) -> Dict[str, Any]:
        """Fetch current state from the proxy (async).

        Returns:
            Dict with 'nodes': List[SequenceNode]
        """
        data = await self._get(f"/{self.session_uuid}/nodes")
        nodes = []
        for node_data in data.get("nodes", []):
            nodes.append(SequenceNode(**node_data))
        self._cached_nodes = nodes
        return {"nodes": nodes}

    def reset(self):
        """Clear cached state. The actual cleanup happens in __aexit__."""
        self._cached_nodes = None

    async def cleanup(self):
        """Delete the session on the proxy (best-effort; failures are logged)."""
        try:
            await self._delete(f"/{self.session_uuid}")
        except Exception as e:
            logger.warning(f"Failed to cleanup proxy session {self.session_uuid}: {e}")

    # -- context manager support --

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Fetch final state before cleanup so callers can still access nodes
        try:
            await self.fetch_state()
        except Exception:
            pass
        await self.cleanup()
||||
async def create_proxy_session(
    proxy_url: str,
    base_url: Optional[str] = None,
    tool_parser: str = "hermes",
    track_tree: bool = False,
    model_name: str = "unknown",
) -> ProxyManagedServer:
    """Register a new session with the proxy and wrap it in a client.

    Args:
        proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
        base_url: Pin to a specific backend server. In production, this
            comes from the atropos API's server allocation.
        tool_parser: vLLM tool parser name (default: "hermes").
        track_tree: Whether to use tree mode for tracking.
        model_name: Model name for response objects.

    Returns:
        ProxyManagedServer instance ready to use.

    Raises:
        RuntimeError: if the proxy rejects the session-creation request.
    """
    root = proxy_url.rstrip("/")

    # Build the session-creation payload; base_url is only sent when pinned.
    payload = {"tool_parser": tool_parser, "track_tree": track_tree}
    if base_url:
        payload["base_url"] = base_url

    async with aiohttp.ClientSession() as http:
        async with http.post(
            f"{root}/sessions/create",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as resp:
            data = await resp.json()
            if resp.status != 200:
                detail = data.get("error", {}).get("message", str(data))
                raise RuntimeError(f"Failed to create proxy session: {detail}")

    return ProxyManagedServer(
        proxy_url=proxy_url,
        session_uuid=data["uuid"],
        model_name=data.get("model_name", model_name),
        base_url=data.get("base_url"),
    )
|
|
@ -41,6 +41,14 @@ class ServerManagerConfig(BaseModel):
|
|||
"This is to help load balance servers."
|
||||
),
|
||||
)
|
||||
proxy_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"URL of the ManagedServer OpenAI proxy (e.g. 'http://localhost:9100'). "
|
||||
"When set, managed_server(use_proxy=True) routes through this proxy. "
|
||||
"Can also be set via ATROPOS_PROXY_URL environment variable."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ServerManager:
|
||||
|
|
@ -52,9 +60,18 @@ class ServerManager:
|
|||
testing=False,
|
||||
max_n_completions=8,
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
proxy_url: Optional[str] = None,
|
||||
use_proxy: bool = False,
|
||||
tool_parser: Optional[str] = None,
|
||||
):
|
||||
self.max_n_completions = max_n_completions
|
||||
self.reasoning_config = reasoning_config
|
||||
# Proxy config — when use_proxy=True, managed_server() routes
|
||||
# through the proxy HTTP API instead of creating in-process instances
|
||||
self.proxy_url = proxy_url or os.environ.get("ATROPOS_PROXY_URL")
|
||||
self.use_proxy = use_proxy or bool(self.proxy_url)
|
||||
# Tool parser — passed to ManagedServer for tool call support
|
||||
self.tool_parser = tool_parser
|
||||
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
|
||||
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
|
||||
# an ABCMeta, not what you're expecting.
|
||||
|
|
@ -364,8 +381,10 @@ class ServerManager:
|
|||
|
||||
@asynccontextmanager
|
||||
async def managed_server(
|
||||
self, tokenizer=None
|
||||
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
|
||||
self,
|
||||
tokenizer=None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Context manager that provides a ManagedServer instance.
|
||||
|
||||
|
|
@ -379,25 +398,63 @@ class ServerManager:
|
|||
Args:
|
||||
tokenizer: Optional tokenizer to use. If not provided, will attempt to
|
||||
extract from server or create from model name.
|
||||
base_url: Pin the session to a specific backend server by its base_url.
|
||||
In production, this comes from the atropos API's server allocation.
|
||||
|
||||
Yields:
|
||||
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
|
||||
the selected server
|
||||
ManagedServer, DummyManagedServer, or ProxyManagedServer instance
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If using OpenAI server without the
|
||||
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
|
||||
|
||||
Example:
|
||||
# In-process (default):
|
||||
async with server_manager.managed_server() as managed:
|
||||
response = await managed.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
n=2
|
||||
n=2, tools=[...], tool_choice="auto",
|
||||
)
|
||||
state = managed.get_state()
|
||||
# Process state...
|
||||
# State is automatically cleared when exiting context
|
||||
|
||||
# Via proxy (configured at init with proxy_url= or ATROPOS_PROXY_URL):
|
||||
# server_manager = ServerManager(configs, proxy_url="http://proxy:9100")
|
||||
async with server_manager.managed_server() as managed:
|
||||
response = await managed.chat_completion(...)
|
||||
api_url = managed.get_url() # for external apps
|
||||
"""
|
||||
# -- Proxy path --
|
||||
if self.use_proxy:
|
||||
resolved_proxy_url = self.proxy_url
|
||||
if not resolved_proxy_url:
|
||||
raise ValueError(
|
||||
"use_proxy=True requires proxy_url or ATROPOS_PROXY_URL env var "
|
||||
"to be set at ServerManager init"
|
||||
)
|
||||
|
||||
from atroposlib.envs.server_handling.proxy_client import (
|
||||
create_proxy_session,
|
||||
)
|
||||
|
||||
model_name = (
|
||||
self.servers[0].config.model_name
|
||||
if self.servers and hasattr(self.servers[0], "config")
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
proxy_managed = await create_proxy_session(
|
||||
proxy_url=resolved_proxy_url,
|
||||
base_url=base_url,
|
||||
tool_parser=self.tool_parser or "hermes",
|
||||
model_name=model_name,
|
||||
)
|
||||
try:
|
||||
yield proxy_managed
|
||||
finally:
|
||||
await proxy_managed.cleanup()
|
||||
return
|
||||
|
||||
# -- In-process path (existing logic) --
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
for i, server in enumerate(self.servers):
|
||||
|
|
@ -441,7 +498,11 @@ class ServerManager:
|
|||
finally:
|
||||
managed.reset()
|
||||
else:
|
||||
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
managed = ManagedServer(
|
||||
server=selected_server,
|
||||
tokenizer=tokenizer,
|
||||
tool_parser=self.tool_parser,
|
||||
)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
|
|
|
|||
251
atroposlib/envs/server_handling/tool_call_translator.py
Normal file
251
atroposlib/envs/server_handling/tool_call_translator.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
"""
|
||||
Bidirectional translation between OpenAI tool_calls format and raw model text.
|
||||
|
||||
Uses vLLM's tool parsers directly — same parsing logic as vLLM's chat
|
||||
completions endpoint. Supports 30+ model-specific parsers (hermes, llama,
|
||||
mistral, deepseek, qwen3, etc.) via ToolParserManager.
|
||||
|
||||
Outbound (model → client): raw text with <tool_call> tags → structured OpenAI tool_calls
|
||||
Inbound (client → model): OpenAI messages with tool roles → raw text for chat template
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# vLLM is optional — tool call parsing degrades gracefully without it
|
||||
try:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
ChatCompletionRequest = None
|
||||
ExtractedToolCallInformation = None
|
||||
FunctionCall = None
|
||||
ToolCall = None
|
||||
ToolParser = None
|
||||
ToolParserManager = None
|
||||
|
||||
|
||||
class ToolCallTranslator:
    """Bidirectional translation between OpenAI tool_calls and raw model text.

    Uses vLLM's tool parsers directly for outbound parsing (model output →
    OpenAI format). Maintains a lookup table mapping tool_call IDs back to
    the raw text that produced them, for reconstructing messages when the
    caller sends tool results back.

    The ManagedServer always stores raw text — this translator only
    transforms what goes over the HTTP wire.
    """

    def __init__(self, tokenizer: Any, parser_name: str = "hermes"):
        """
        Args:
            tokenizer: HuggingFace tokenizer instance.
            parser_name: Name of the vLLM tool parser to use.
                Available: hermes, llama3_json, llama4_json, mistral,
                deepseek_v3, qwen3_coder, granite, internlm, etc.
                See ToolParserManager.list_registered() for full list.

        Raises:
            Warning if vLLM is not installed — tool call parsing will be
            disabled but the translator can still handle message conversion
            and decoding.
        """
        self.tokenizer = tokenizer
        self.parser_name = parser_name
        self.parser = None

        if not VLLM_AVAILABLE:
            warnings.warn(
                "vLLM is not installed — tool call parsing is disabled. "
                "Install vllm to enable structured tool call extraction from "
                "model output (pip install vllm). The translator will still "
                "handle message conversion and template rendering, but "
                "parse_model_output() will return raw text without parsing.",
                stacklevel=2,
            )
        else:
            ParserClass = ToolParserManager.get_tool_parser(parser_name)
            self.parser = ParserClass(tokenizer)

        # tool_call_id → raw text segment that produced it.
        # Used to reconstruct assistant messages when the caller sends
        # follow-up messages with tool results.
        self.call_id_to_raw_text: Dict[str, str] = {}

    def parse_model_output(
        self,
        raw_text: str,
        tool_choice: Optional[str] = None,
        tools: Optional[List[dict]] = None,
    ) -> Tuple[Optional[str], Optional[List[ToolCall]], str]:
        """Parse raw model output into OpenAI response fields.

        Args:
            raw_text: Raw model output text (may contain <tool_call> tags etc.)
            tool_choice: The tool_choice from the request. If "none", skip
                parsing entirely. If "required", force tool_calls interpretation.
            tools: Tool definitions from the request (needed for vLLM request obj).

        Returns:
            Tuple of (content, tool_calls, finish_reason):
                content: Text content (before tool calls, or full text if no tools)
                tool_calls: List of ToolCall objects, or None
                finish_reason: "stop" or "tool_calls"
        """
        # If tool_choice is "none" or no tools defined, don't even try parsing
        if tool_choice == "none" or not tools:
            return raw_text, None, "stop"

        # If vLLM isn't available, can't parse — return raw text
        if self.parser is None:
            return raw_text, None, "stop"

        # Build a minimal ChatCompletionRequest for the parser
        request = ChatCompletionRequest(
            messages=[{"role": "user", "content": ""}],
            model="proxy",
            tools=tools,
            tool_choice=tool_choice,
        )

        # Fix: vLLM parsers can raise on malformed model output (e.g.
        # truncated JSON inside <tool_call> tags). A parse failure must
        # degrade to returning the raw text — not fail the whole request —
        # mirroring vLLM's own serving behavior.
        try:
            result: ExtractedToolCallInformation = self.parser.extract_tool_calls(
                raw_text, request
            )
        except Exception:
            logger.warning(
                "Tool parser %r failed on model output; returning raw text",
                self.parser_name,
                exc_info=True,
            )
            return raw_text, None, "stop"

        if result.tools_called and result.tool_calls:
            # Store mapping for reverse direction
            for tc in result.tool_calls:
                self.call_id_to_raw_text[tc.id] = raw_text

            return result.content, result.tool_calls, "tool_calls"
        else:
            return raw_text, None, "stop"

    def reconstruct_raw_text_from_tool_calls(self, tool_calls: List[dict]) -> str:
        """Reconstruct raw model text from OpenAI-format tool_calls.

        When a caller sends an assistant message with tool_calls (e.g. in a
        multi-turn conversation), we need to convert it back to the raw text
        the model actually generated so the chat template produces the right
        tokens.

        First tries the lookup table (exact reconstruction). Falls back to
        rebuilding from the structured data (best-effort).

        Args:
            tool_calls: List of tool call dicts from OpenAI format, each with
                'id', 'type', 'function' (containing 'name' and 'arguments').

        Returns:
            Raw text with <tool_call> tags (or whatever format the parser uses).
        """
        if not tool_calls:
            return ""

        # Try lookup table first — if the first call's ID is in the table,
        # we can return the exact raw text
        first_id = tool_calls[0].get("id", "")
        if first_id in self.call_id_to_raw_text:
            return self.call_id_to_raw_text[first_id]

        # Fallback: reconstruct from structured data
        # This is best-effort — the exact formatting may differ from what
        # the model originally generated, but it's close enough for the
        # chat template to handle.
        # TODO: make the tag format configurable per parser (hermes uses
        # <tool_call>, others may differ)
        parts = []
        for tc in tool_calls:
            func = tc.get("function", {})
            name = func.get("name", "")
            arguments = func.get("arguments", "{}")
            # Parse arguments string back to dict for clean formatting
            try:
                args_dict = (
                    json.loads(arguments) if isinstance(arguments, str) else arguments
                )
            except (json.JSONDecodeError, TypeError):
                args_dict = arguments
            call_obj = {"name": name, "arguments": args_dict}
            parts.append(f"<tool_call>{json.dumps(call_obj)}</tool_call>")

        return "\n".join(parts)

    def convert_messages_for_template(
        self, messages: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Convert OpenAI messages to raw format suitable for apply_chat_template.

        Handles three cases:
        1. Regular messages (user, system): pass through unchanged
        2. Assistant messages with tool_calls: replace with raw text content
        3. Tool result messages (role=tool): pass through (chat template handles them)

        Args:
            messages: OpenAI-format messages list.

        Returns:
            Messages list with tool_call assistant messages reconstructed to raw text.
        """
        converted = []
        for msg in messages:
            role = msg.get("role", "")

            if role == "assistant" and msg.get("tool_calls"):
                # Reconstruct raw text from tool_calls
                raw_text = self.reconstruct_raw_text_from_tool_calls(msg["tool_calls"])
                # Prepend any content that came before the tool calls
                content = msg.get("content") or ""
                if content:
                    raw_text = content + "\n" + raw_text

                converted.append(
                    {
                        "role": "assistant",
                        "content": raw_text,
                    }
                )
            else:
                # Pass through as-is (user, system, tool, regular assistant)
                converted.append(msg)

        return converted

    def decode_with_tool_awareness(
        self,
        token_ids: List[int],
        has_tools: bool = False,
    ) -> str:
        """Decode token IDs, preserving tool call tags when tools are active.

        Some tokenizers mark <tool_call> as a special token. If we decode with
        skip_special_tokens=True (the default), the tags vanish before the
        parser ever sees them. When tools are in play, we decode with
        skip_special_tokens=False to preserve them.

        Args:
            token_ids: Token IDs to decode.
            has_tools: Whether tools are active for this request.

        Returns:
            Decoded text string.
        """
        return self.tokenizer.decode(
            token_ids,
            skip_special_tokens=not has_tools,
        )
Loading…
Add table
Add a link
Reference in a new issue