# Mirror of https://github.com/NousResearch/atropos.git
# (synced 2026-04-19 12:57:58 +00:00; 292 lines, 10 KiB, Python)
"""
Client that talks to the ManagedServer OpenAI proxy over HTTP.

Implements the same interface as ManagedServer so it can be used as a
drop-in replacement via ServerManager.managed_server(use_proxy=True).

The proxy handles all the token tracking, tool call parsing, and sequence
management. This client just ferries requests/responses over HTTP and
reconstructs the SequenceNode objects from the JSON.
"""
|
|
|
|
import logging
|
|
import time
|
|
import uuid as uuid_lib
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
from openai.types.chat.chat_completion import (
|
|
ChatCompletion,
|
|
ChatCompletionMessage,
|
|
Choice,
|
|
)
|
|
from openai.types.completion import Completion # noqa: F401 — used in type hint
|
|
|
|
from atroposlib.envs.server_handling.managed_server import SequenceNode
|
|
|
|
# Module-scoped logger. It is named after this module's import path, so it
# picks up level/handler configuration through the standard logger hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ProxyManagedServer:
    """Client that talks to the ManagedServer OpenAI proxy.

    Same interface as ManagedServer — chat_completion(), completion(),
    get_state(), reset(). But instead of doing token tracking in-process,
    it delegates everything to the proxy over HTTP.

    Created by ServerManager.managed_server(use_proxy=True).

    Example:
        async with server_manager.managed_server(use_proxy=True) as managed:
            # Same API as regular ManagedServer
            resp = await managed.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                n=4, max_tokens=100, temperature=1.0,
            )

            # get_state() only reads a local cache; refresh it first.
            await managed.fetch_state()
            nodes = managed.get_state()["nodes"]

            # Extra: get URL for external apps to use directly
            url = managed.get_url()
            # → "http://proxy:9100/{uuid}/v1"
    """

    def __init__(
        self,
        proxy_url: str,
        session_uuid: str,
        model_name: str = "unknown",
        base_url: Optional[str] = None,
    ):
        """
        Args:
            proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
                A trailing "/" is stripped.
            session_uuid: UUID of the session on the proxy.
            model_name: Fallback model name for response objects, used when
                the proxy's reply carries no "model".
            base_url: The backend server this session is pinned to (stored
                for reference; this class does not send it anywhere).
        """
        self.proxy_url = proxy_url.rstrip("/")
        self.session_uuid = session_uuid
        self.model_name = model_name
        self.base_url = base_url

        # Nodes from the most recent fetch_state(); None until then and
        # again after reset().
        self._cached_nodes: Optional[List[SequenceNode]] = None

    def get_url(self) -> str:
        """Get the OpenAI-compatible API URL for this session.

        External apps can use this URL with any OpenAI client:
            client = openai.OpenAI(base_url=managed.get_url())
            client.chat.completions.create(messages=..., tools=...)

        Returns:
            URL like "http://proxy:9100/{uuid}/v1"
        """
        return f"{self.proxy_url}/{self.session_uuid}/v1"

    # -- HTTP helpers --

    @staticmethod
    def _decode_body(raw: bytes) -> Any:
        """Decode a response body as JSON.

        An empty (or whitespace-only) body decodes to ``{}`` so that replies
        without content (e.g. 204 No Content) do not raise.

        Raises:
            ValueError: If the body is non-empty but not valid JSON.
        """
        import json

        if not raw.strip():
            return {}
        return json.loads(raw)

    @staticmethod
    def _error_message(data: Any, raw: bytes, limit: int = 500) -> str:
        """Extract a human-readable message from a proxy error response.

        Prefers an OpenAI-style ``{"error": {"message": ...}}`` payload and
        also accepts ``{"error": "..."}``. Any other non-empty JSON object
        is stringified whole. Otherwise (no JSON, e.g. an HTML gateway
        page) the first ``limit`` characters of the raw body are used.
        """
        if isinstance(data, dict) and data:
            error = data.get("error")
            if isinstance(error, dict) and "message" in error:
                return str(error["message"])
            if isinstance(error, str) and error:
                return error
            return str(data)
        text = raw.decode("utf-8", errors="replace").strip()
        return text[:limit] or "<empty response body>"

    async def _request(
        self,
        method: str,
        path: str,
        *,
        timeout: int,
        body: Optional[dict] = None,
        check_status: bool = True,
    ) -> Any:
        """Send one HTTP request to the proxy and return the decoded JSON.

        Args:
            method: HTTP method, e.g. "GET", "POST", "DELETE".
            path: Path appended to ``self.proxy_url`` (starts with "/").
            timeout: Total request timeout in seconds.
            body: Optional JSON request body.
            check_status: If True, raise on a non-2xx status or a non-JSON
                body. If False, return whatever JSON was decoded, or ``{}``
                when the body is empty or not JSON.

        Raises:
            RuntimeError: When ``check_status`` is True and the proxy
                replied with a non-2xx status or a non-JSON body.
        """
        url = f"{self.proxy_url}{path}"
        # A fresh ClientSession per call keeps this simple (no session
        # lifecycle to manage) at the cost of connection reuse.
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method,
                url,
                json=body,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as resp:
                status = resp.status
                # Read raw bytes rather than resp.json(): a non-JSON error
                # body must not hide the real status code.
                raw = await resp.read()

        try:
            data = self._decode_body(raw)
        except ValueError:
            data = None

        if check_status:
            if not 200 <= status < 300:
                raise RuntimeError(
                    f"Proxy request failed ({status}): "
                    f"{self._error_message(data, raw)}"
                )
            if data is None:
                raise RuntimeError(
                    f"Proxy returned a non-JSON response ({status}): "
                    f"{self._error_message(None, raw)}"
                )
        return {} if data is None else data

    async def _post(self, path: str, json: dict, timeout: int = 300) -> dict:
        """POST ``json`` to the proxy; raise RuntimeError on failure."""
        return await self._request("POST", path, body=json, timeout=timeout)

    async def _get(self, path: str, timeout: int = 30) -> dict:
        """GET from the proxy; raise RuntimeError on failure."""
        return await self._request("GET", path, timeout=timeout)

    async def _delete(self, path: str, timeout: int = 30) -> dict:
        """DELETE on the proxy.

        The status is not checked (the decoded body is returned either way),
        and an empty or non-JSON body yields ``{}``. Network errors still
        propagate.
        """
        return await self._request(
            "DELETE", path, timeout=timeout, check_status=False
        )

    # -- chat completions --

    @staticmethod
    def _serialize_message(message: Any) -> Dict[str, Any]:
        """Convert one chat message into a JSON-serializable dict.

        Plain dicts pass through unchanged. Objects exposing ``model_dump()``
        (pydantic/openai models such as a ``ChatCompletionMessage`` reused
        from an earlier response) are dumped recursively, so nested objects
        like ``tool_calls`` become plain dicts; ``dict(model)`` would leave
        them as model instances the JSON encoder rejects. Anything else goes
        through ``dict()``.
        """
        if isinstance(message, dict):
            return message
        model_dump = getattr(message, "model_dump", None)
        if callable(model_dump):
            return model_dump()
        return dict(message)

    @classmethod
    def _build_chat_body(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the proxy request body from chat_completion() kwargs.

        Only messages, max_tokens (default 1024), temperature (default 1.0),
        n (default 1), and, when set, stop/tools/tool_choice are forwarded;
        other kwargs are not sent to the proxy.
        """
        body: Dict[str, Any] = {
            "messages": [
                cls._serialize_message(m) for m in (kwargs.get("messages") or [])
            ],
            "max_tokens": kwargs.get("max_tokens", 1024),
            "temperature": kwargs.get("temperature", 1.0),
            "n": kwargs.get("n", 1),
        }
        if kwargs.get("stop"):
            body["stop"] = kwargs["stop"]
        if kwargs.get("tools"):
            body["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice") is not None:
            body["tool_choice"] = kwargs["tool_choice"]
        return body

    def _build_chat_completion(self, data: Dict[str, Any]) -> "ChatCompletion":
        """Rebuild an openai ChatCompletion from the proxy's JSON reply."""
        choices = []
        for position, choice_data in enumerate(data.get("choices") or []):
            msg = choice_data.get("message") or {}
            choices.append(
                Choice(
                    # A null finish_reason would fail the model's validation.
                    finish_reason=choice_data.get("finish_reason") or "stop",
                    # Fall back to list position so n>1 choices stay distinct.
                    index=choice_data.get("index", position),
                    message=ChatCompletionMessage(
                        content=msg.get("content"),
                        role=msg.get("role") or "assistant",
                        # Forward the tool calls the proxy parsed; an empty
                        # list is normalized to None (no tool calls).
                        tool_calls=msg.get("tool_calls") or None,
                    ),
                )
            )

        return ChatCompletion(
            id=data.get("id") or str(uuid_lib.uuid4()),
            created=data.get("created") or int(time.time()),
            model=data.get("model") or self.model_name,
            object="chat.completion",
            choices=choices,
        )

    async def chat_completion(self, **kwargs) -> "ChatCompletion":
        """Send a chat completion request through the proxy.

        Same interface as ManagedServer.chat_completion(). The proxy handles
        template rendering, tool call parsing, and token/logprob tracking.
        Only messages, max_tokens, temperature, n, stop, tools and
        tool_choice are forwarded.

        Raises:
            RuntimeError: If the proxy replies with an error or non-JSON.
        """
        body = self._build_chat_body(kwargs)
        data = await self._post(
            f"/{self.session_uuid}/v1/chat/completions", json=body
        )
        return self._build_chat_completion(data)

    async def completion(self, **kwargs) -> "Completion":
        """Not supported: always raises NotImplementedError.

        The proxy currently only exposes chat/completions and renders the
        prompt via the chat template itself, so use chat_completion().
        """
        raise NotImplementedError(
            "ProxyManagedServer.completion() is not supported. "
            "Use chat_completion() instead — the proxy handles template "
            "rendering internally."
        )

    # -- state --

    def get_state(self) -> Dict[str, Any]:
        """Return the cached state synchronously.

        The cache is filled only by fetch_state(), which __aexit__ also
        calls. chat_completion() does NOT refresh it, so call fetch_state()
        after completions to see current nodes.

        Returns:
            Dict with 'nodes': List[SequenceNode] (empty if never fetched
            or after reset()).
        """
        if self._cached_nodes is not None:
            return {"nodes": self._cached_nodes}
        return {"nodes": []}

    async def fetch_state(self) -> Dict[str, Any]:
        """Fetch current state from the proxy and refresh the cache.

        Returns:
            Dict with 'nodes': List[SequenceNode]

        Raises:
            RuntimeError: If the proxy replies with an error or non-JSON.
        """
        data = await self._get(f"/{self.session_uuid}/nodes")
        nodes = [SequenceNode(**node_data) for node_data in data.get("nodes", [])]
        self._cached_nodes = nodes
        return {"nodes": nodes}

    def reset(self):
        """Clear cached state. The actual cleanup happens in __aexit__."""
        self._cached_nodes = None

    async def cleanup(self):
        """Delete the session on the proxy (best effort; failures are logged)."""
        try:
            await self._delete(f"/{self.session_uuid}")
        except Exception as e:
            logger.warning(
                "Failed to cleanup proxy session %s: %s", self.session_uuid, e
            )

    # -- context manager support --

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Fetch final state before cleanup so callers can still access nodes
        # via get_state(). The session is deleted even if the fetch fails or
        # is cancelled.
        try:
            await self.fetch_state()
        except Exception as e:
            logger.warning(
                "Failed to fetch final state for proxy session %s: %s",
                self.session_uuid,
                e,
            )
        finally:
            await self.cleanup()
|
|
|
|
|
|
async def create_proxy_session(
    proxy_url: str,
    base_url: Optional[str] = None,
    tool_parser: str = "hermes",
    track_tree: bool = False,
    model_name: str = "unknown",
) -> ProxyManagedServer:
    """Create a new session on the proxy and return a ProxyManagedServer.

    Args:
        proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
        base_url: Pin to a specific backend server. In production, this
            comes from the atropos API's server allocation. Not sent when
            falsy.
        tool_parser: vLLM tool parser name (default: "hermes").
        track_tree: Whether to use tree mode for tracking.
        model_name: Fallback model name for response objects, used when the
            proxy's reply carries no "model_name".

    Returns:
        ProxyManagedServer instance bound to the newly created session.

    Raises:
        RuntimeError: If the proxy replies with a non-2xx status, or its
            reply is not a JSON object containing a session "uuid".
    """
    import json

    def _describe(data: Any, raw: bytes) -> str:
        # Prefer an OpenAI-style {"error": {"message": ...}} payload, then
        # {"error": "..."}, then the JSON itself, and finally a truncated raw
        # body (e.g. an HTML page from a gateway in front of the proxy).
        if isinstance(data, dict) and data:
            error = data.get("error")
            if isinstance(error, dict) and "message" in error:
                return str(error["message"])
            if isinstance(error, str) and error:
                return error
            return str(data)
        text = raw.decode("utf-8", errors="replace").strip()
        return text[:500] or "<empty response body>"

    body: Dict[str, Any] = {
        "tool_parser": tool_parser,
        "track_tree": track_tree,
    }
    if base_url:
        body["base_url"] = base_url

    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{proxy_url.rstrip('/')}/sessions/create",
            json=body,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as resp:
            status = resp.status
            # Read raw bytes rather than resp.json(): a non-JSON error body
            # must not hide the real failure.
            raw = await resp.read()

    try:
        data = json.loads(raw) if raw.strip() else {}
    except ValueError:
        data = None

    if not 200 <= status < 300:
        raise RuntimeError(f"Failed to create proxy session: {_describe(data, raw)}")
    if not isinstance(data, dict) or not data.get("uuid"):
        raise RuntimeError(
            "Failed to create proxy session: response has no session uuid "
            f"({_describe(data, raw)})"
        )

    return ProxyManagedServer(
        proxy_url=proxy_url,
        session_uuid=data["uuid"],
        model_name=data.get("model_name", model_name),
        base_url=data.get("base_url"),
    )
|