atropos/atroposlib/envs/server_handling/proxy_client.py

292 lines
10 KiB
Python

"""
Client that talks to the ManagedServer OpenAI proxy over HTTP.
Implements the same interface as ManagedServer so it can be used as a
drop-in replacement via ServerManager.managed_server(use_proxy=True).
The proxy handles all the token tracking, tool call parsing, and sequence
management. This client just ferries requests/responses over HTTP and
reconstructs the SequenceNode objects from the JSON.
"""
import json
import logging
import time
import uuid as uuid_lib
from typing import Any, Dict, List, Optional

import aiohttp
from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)
from openai.types.completion import Completion  # noqa: F401 — used in type hint

from atroposlib.envs.server_handling.managed_server import SequenceNode
# Module-level logger; ProxyManagedServer.cleanup() uses it to report a
# failed session deletion as a warning instead of raising.
logger = logging.getLogger(__name__)
class ProxyManagedServer:
    """Client that talks to the ManagedServer OpenAI proxy.

    Mirrors ManagedServer's interface — chat_completion(), completion(),
    get_state(), reset() — but instead of doing token tracking in-process,
    delegates everything to the proxy over HTTP. (completion() is present
    for interface parity only and always raises NotImplementedError.)

    Created by ServerManager.managed_server(use_proxy=True).

    Example:
        async with server_manager.managed_server(use_proxy=True) as managed:
            # Same API as regular ManagedServer
            resp = await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                n=4, max_tokens=100, temperature=1.0,
            )
            # Nodes live on the proxy; pull them explicitly while the
            # session is open. (get_state() only reads the local cache,
            # which fetch_state() and __aexit__ populate.)
            state = await managed.fetch_state()
            nodes = state["nodes"]

            # Extra: get URL for external apps to use directly
            url = managed.get_url()
            # → "http://proxy:9100/{uuid}/v1"
    """

    def __init__(
        self,
        proxy_url: str,
        session_uuid: str,
        model_name: str = "unknown",
        base_url: Optional[str] = None,
    ):
        """
        Args:
            proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
                A trailing slash is stripped.
            session_uuid: UUID of the session on the proxy.
            model_name: Model name used in reconstructed response objects
                when the proxy's response does not provide one.
            base_url: The backend server this session is pinned to.
        """
        self.proxy_url = proxy_url.rstrip("/")
        self.session_uuid = session_uuid
        self.model_name = model_name
        self.base_url = base_url
        # Cache for nodes; populated only by fetch_state().
        self._cached_nodes: Optional[List[SequenceNode]] = None

    def get_url(self) -> str:
        """Get the OpenAI-compatible API URL for this session.

        External apps can use this URL with any OpenAI client:
            client = openai.OpenAI(base_url=managed.get_url())
            client.chat.completions.create(messages=..., tools=...)

        Returns:
            URL like "http://proxy:9100/{uuid}/v1"
        """
        return f"{self.proxy_url}/{self.session_uuid}/v1"

    # -- HTTP helpers --

    @staticmethod
    def _error_message(data: Any, raw_text: str) -> str:
        """Extract a human-readable error message from a failed response.

        Prefers the OpenAI-style ``{"error": {"message": ...}}`` shape, also
        accepts ``{"error": "..."}``, and otherwise falls back to the whole
        decoded object or, for non-JSON / non-object bodies, the (truncated)
        raw response text.
        """
        if isinstance(data, dict):
            error = data.get("error")
            if isinstance(error, dict) and "message" in error:
                return str(error["message"])
            if isinstance(error, str):
                return error
            return str(data)
        return raw_text[:500] if raw_text.strip() else "<empty response body>"

    @classmethod
    def _decode_response(cls, status: int, text: str) -> dict:
        """Decode a proxy response body and enforce a 2xx status.

        An empty body decodes to ``{}``. The body is parsed from text rather
        than via ``resp.json()`` so that non-JSON error pages (e.g. an HTML
        502 from a load balancer) still yield a readable error.

        Raises:
            RuntimeError: On a non-2xx status, or a 2xx status whose body is
                not valid JSON.
        """
        is_json = True
        try:
            data = json.loads(text) if text.strip() else {}
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            data, is_json = None, False

        if not 200 <= status < 300:
            raise RuntimeError(
                f"Proxy request failed ({status}): {cls._error_message(data, text)}"
            )
        if not is_json:
            raise RuntimeError(
                f"Proxy returned a non-JSON response ({status}): {text[:500]}"
            )
        return data

    async def _request(
        self,
        method: str,
        path: str,
        *,
        timeout: float,
        payload: Optional[dict] = None,
    ) -> dict:
        """Send one HTTP request to the proxy and return the decoded JSON.

        A fresh aiohttp session is opened per call. Network errors from
        aiohttp propagate unchanged; HTTP-level failures raise RuntimeError.
        """
        url = f"{self.proxy_url}{path}"
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method,
                url,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as resp:
                # errors="replace" so a mis-encoded error page cannot mask
                # the real failure with a UnicodeDecodeError.
                text = await resp.text(errors="replace")
                return self._decode_response(resp.status, text)

    async def _post(self, path: str, json: dict, timeout: int = 300) -> dict:
        """Make a POST request to the proxy with a JSON body."""
        return await self._request("POST", path, timeout=timeout, payload=json)

    async def _get(self, path: str, timeout: int = 30) -> dict:
        """Make a GET request to the proxy."""
        return await self._request("GET", path, timeout=timeout)

    async def _delete(self, path: str, timeout: int = 30) -> dict:
        """Make a DELETE request to the proxy (raises on non-2xx, like the others)."""
        return await self._request("DELETE", path, timeout=timeout)

    # -- OpenAI-style API --

    @staticmethod
    def _serialize_messages(messages: Any) -> List[Dict[str, Any]]:
        """Convert chat messages into JSON-serializable dicts.

        Plain dicts pass through unchanged. Pydantic models (e.g. a
        ChatCompletionMessage from a previous response) are dumped
        recursively, so nested objects such as ``tool_calls`` become plain
        JSON values. Anything else is converted with ``dict()``.
        """
        serialized = []
        for msg in messages:
            if isinstance(msg, dict):
                serialized.append(msg)
            elif hasattr(msg, "model_dump"):
                serialized.append(msg.model_dump(mode="json"))
            else:
                serialized.append(dict(msg))
        return serialized

    async def chat_completion(self, **kwargs) -> "ChatCompletion":
        """Send a chat completion request through the proxy.

        Same interface as ManagedServer.chat_completion(). The proxy handles
        template rendering, tool call parsing, and token/logprob tracking.

        Forwarded kwargs: ``messages``, ``max_tokens`` (default 1024),
        ``temperature`` (default 1.0), ``n`` (default 1), plus ``stop`` and
        ``tools`` when truthy and ``tool_choice`` when not None.

        Returns:
            A ChatCompletion rebuilt from the proxy's JSON, including any
            ``tool_calls`` the proxy parsed for each choice.

        Raises:
            RuntimeError: If the proxy returns a non-2xx status or non-JSON body.
        """
        body: Dict[str, Any] = {
            "messages": self._serialize_messages(kwargs.get("messages", [])),
            "max_tokens": kwargs.get("max_tokens", 1024),
            "temperature": kwargs.get("temperature", 1.0),
            "n": kwargs.get("n", 1),
        }
        # NOTE(review): other sampling kwargs (e.g. top_p, seed) are not
        # forwarded — confirm the proxy accepts them before adding.
        if kwargs.get("stop"):
            body["stop"] = kwargs["stop"]
        if kwargs.get("tools"):
            body["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice") is not None:
            body["tool_choice"] = kwargs["tool_choice"]

        data = await self._post(f"/{self.session_uuid}/v1/chat/completions", json=body)

        # Reconstruct ChatCompletion from the proxy response. `or` fallbacks
        # also cover explicit nulls, which the openai models would reject.
        choices = []
        for position, choice_data in enumerate(data.get("choices", [])):
            msg = choice_data.get("message") or {}
            choices.append(
                Choice(
                    finish_reason=choice_data.get("finish_reason") or "stop",
                    index=choice_data.get("index", position),
                    message=ChatCompletionMessage(
                        content=msg.get("content"),
                        role=msg.get("role", "assistant"),
                        # The proxy parses tool calls; keep them instead of
                        # silently dropping them.
                        tool_calls=msg.get("tool_calls") or None,
                    ),
                )
            )

        return ChatCompletion(
            id=data.get("id") or str(uuid_lib.uuid4()),
            created=data.get("created") or int(time.time()),
            model=data.get("model") or self.model_name,
            object="chat.completion",
            choices=choices,
        )

    async def completion(self, **kwargs) -> "Completion":
        """Not supported: the proxy only exposes a chat/completions endpoint.

        Present so this class matches ManagedServer's interface. Use
        chat_completion() instead — the proxy applies the chat template
        itself.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "ProxyManagedServer.completion() is not supported. "
            "Use chat_completion() instead — the proxy handles template "
            "rendering internally."
        )

    # -- state --

    def get_state(self) -> Dict[str, Any]:
        """Return the locally cached state without contacting the proxy.

        The cache is filled only by fetch_state(), which __aexit__ also calls
        before deleting the session. chat_completion() does NOT update it, so
        inside an open session call ``await fetch_state()`` for fresh nodes.

        Returns:
            Dict with 'nodes': List[SequenceNode] (empty if never fetched or
            after reset()).
        """
        if self._cached_nodes is not None:
            return {"nodes": self._cached_nodes}
        return {"nodes": []}

    async def fetch_state(self) -> Dict[str, Any]:
        """Fetch the session's nodes from the proxy and refresh the cache.

        Returns:
            Dict with 'nodes': List[SequenceNode], built from the proxy's
            ``GET /{uuid}/nodes`` response.

        Raises:
            RuntimeError: If the proxy request fails.
        """
        data = await self._get(f"/{self.session_uuid}/nodes")
        nodes = [SequenceNode(**node_data) for node_data in data.get("nodes", [])]
        self._cached_nodes = nodes
        return {"nodes": nodes}

    def reset(self):
        """Clear cached state. The proxy-side session is deleted in __aexit__."""
        self._cached_nodes = None

    async def cleanup(self):
        """Delete the session on the proxy; failures are logged, not raised."""
        try:
            await self._delete(f"/{self.session_uuid}")
        except Exception as e:
            logger.warning(f"Failed to cleanup proxy session {self.session_uuid}: {e}")

    # -- context manager support --

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Snapshot the final nodes into the cache, then delete the session.

        Both steps log failures instead of raising, so an exception from the
        ``async with`` body is not masked; returning None lets it propagate.
        """
        # Fetch final state before cleanup so callers can still access nodes
        # via get_state() after the context exits.
        try:
            await self.fetch_state()
        except Exception as e:
            logger.warning(
                "Failed to fetch final state for proxy session %s: %s",
                self.session_uuid,
                e,
            )
        await self.cleanup()
async def create_proxy_session(
proxy_url: str,
base_url: Optional[str] = None,
tool_parser: str = "hermes",
track_tree: bool = False,
model_name: str = "unknown",
) -> ProxyManagedServer:
"""Create a new session on the proxy and return a ProxyManagedServer.
Args:
proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
base_url: Pin to a specific backend server. In production, this
comes from the atropos API's server allocation.
tool_parser: vLLM tool parser name (default: "hermes").
track_tree: Whether to use tree mode for tracking.
model_name: Model name for response objects.
Returns:
ProxyManagedServer instance ready to use.
"""
body = {
"tool_parser": tool_parser,
"track_tree": track_tree,
}
if base_url:
body["base_url"] = base_url
async with aiohttp.ClientSession() as session:
async with session.post(
f"{proxy_url.rstrip('/')}/sessions/create",
json=body,
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
data = await resp.json()
if resp.status != 200:
error_msg = data.get("error", {}).get("message", str(data))
raise RuntimeError(f"Failed to create proxy session: {error_msg}")
return ProxyManagedServer(
proxy_url=proxy_url,
session_uuid=data["uuid"],
model_name=data.get("model_name", model_name),
base_url=data.get("base_url"),
)