mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
292
atroposlib/envs/server_handling/proxy_client.py
Normal file
292
atroposlib/envs/server_handling/proxy_client.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""
|
||||
Client that talks to the ManagedServer OpenAI proxy over HTTP.
|
||||
|
||||
Implements the same interface as ManagedServer so it can be used as a
|
||||
drop-in replacement via ServerManager.managed_server(use_proxy=True).
|
||||
|
||||
The proxy handles all the token tracking, tool call parsing, and sequence
|
||||
management. This client just ferries requests/responses over HTTP and
|
||||
reconstructs the SequenceNode objects from the JSON.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid as uuid_lib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion import Completion # noqa: F401 — used in type hint
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import SequenceNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyManagedServer:
|
||||
"""Client that talks to the ManagedServer OpenAI proxy.
|
||||
|
||||
Same interface as ManagedServer — chat_completion(), completion(),
|
||||
get_state(), reset(). But instead of doing token tracking in-process,
|
||||
delegates everything to the proxy over HTTP.
|
||||
|
||||
Created by ServerManager.managed_server(use_proxy=True).
|
||||
|
||||
Example:
|
||||
async with server_manager.managed_server(use_proxy=True) as managed:
|
||||
# Same API as regular ManagedServer
|
||||
resp = await managed.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
n=4, max_tokens=100, temperature=1.0,
|
||||
)
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
# Extra: get URL for external apps to use directly
|
||||
url = managed.get_url()
|
||||
# → "http://proxy:9100/{uuid}/v1"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: str,
|
||||
session_uuid: str,
|
||||
model_name: str = "unknown",
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
proxy_url: Base URL of the proxy (e.g. "http://localhost:9100")
|
||||
session_uuid: UUID of the session on the proxy.
|
||||
model_name: Model name (for response objects).
|
||||
base_url: The backend server this session is pinned to.
|
||||
"""
|
||||
self.proxy_url = proxy_url.rstrip("/")
|
||||
self.session_uuid = session_uuid
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
|
||||
# Cache for nodes (populated by get_state)
|
||||
self._cached_nodes: Optional[List[SequenceNode]] = None
|
||||
|
||||
def get_url(self) -> str:
|
||||
"""Get the OpenAI-compatible API URL for this session.
|
||||
|
||||
External apps can use this URL with any OpenAI client:
|
||||
client = openai.OpenAI(base_url=managed.get_url())
|
||||
client.chat.completions.create(messages=..., tools=...)
|
||||
|
||||
Returns:
|
||||
URL like "http://proxy:9100/{uuid}/v1"
|
||||
"""
|
||||
return f"{self.proxy_url}/{self.session_uuid}/v1"
|
||||
|
||||
async def _post(self, path: str, json: dict, timeout: int = 300) -> dict:
|
||||
"""Make a POST request to the proxy."""
|
||||
url = f"{self.proxy_url}{path}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, json=json, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200:
|
||||
error_msg = data.get("error", {}).get("message", str(data))
|
||||
raise RuntimeError(
|
||||
f"Proxy request failed ({resp.status}): {error_msg}"
|
||||
)
|
||||
return data
|
||||
|
||||
async def _get(self, path: str, timeout: int = 30) -> dict:
|
||||
"""Make a GET request to the proxy."""
|
||||
url = f"{self.proxy_url}{path}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200:
|
||||
error_msg = data.get("error", {}).get("message", str(data))
|
||||
raise RuntimeError(
|
||||
f"Proxy request failed ({resp.status}): {error_msg}"
|
||||
)
|
||||
return data
|
||||
|
||||
async def _delete(self, path: str, timeout: int = 30) -> dict:
|
||||
"""Make a DELETE request to the proxy."""
|
||||
url = f"{self.proxy_url}{path}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
url, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
return data
|
||||
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""Send a chat completion request through the proxy.
|
||||
|
||||
Same interface as ManagedServer.chat_completion().
|
||||
The proxy handles template rendering, tool call parsing,
|
||||
and token/logprob tracking.
|
||||
"""
|
||||
# Convert messages to serializable format
|
||||
messages = kwargs.get("messages", [])
|
||||
serialized_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
serialized_messages.append(msg)
|
||||
else:
|
||||
serialized_messages.append(dict(msg))
|
||||
|
||||
body = {
|
||||
"messages": serialized_messages,
|
||||
"max_tokens": kwargs.get("max_tokens", 1024),
|
||||
"temperature": kwargs.get("temperature", 1.0),
|
||||
"n": kwargs.get("n", 1),
|
||||
}
|
||||
if kwargs.get("stop"):
|
||||
body["stop"] = kwargs["stop"]
|
||||
if kwargs.get("tools"):
|
||||
body["tools"] = kwargs["tools"]
|
||||
if kwargs.get("tool_choice") is not None:
|
||||
body["tool_choice"] = kwargs["tool_choice"]
|
||||
|
||||
data = await self._post(f"/{self.session_uuid}/v1/chat/completions", json=body)
|
||||
|
||||
# Reconstruct ChatCompletion from proxy response
|
||||
choices = []
|
||||
for choice_data in data.get("choices", []):
|
||||
msg = choice_data.get("message", {})
|
||||
choice = Choice(
|
||||
finish_reason=choice_data.get("finish_reason", "stop"),
|
||||
index=choice_data.get("index", 0),
|
||||
message=ChatCompletionMessage(
|
||||
content=msg.get("content"),
|
||||
role=msg.get("role", "assistant"),
|
||||
),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
return ChatCompletion(
|
||||
id=data.get("id", str(uuid_lib.uuid4())),
|
||||
created=data.get("created", int(time.time())),
|
||||
model=data.get("model", self.model_name),
|
||||
object="chat.completion",
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""Send a completion request through the proxy.
|
||||
|
||||
Note: the proxy's chat/completions endpoint is the primary interface.
|
||||
For raw completions, the proxy renders the prompt via chat template
|
||||
internally. If you're calling this, you probably want chat_completion()
|
||||
instead.
|
||||
"""
|
||||
# For completion() calls, we'd need a /completions endpoint on the proxy.
|
||||
# Currently the proxy only exposes chat/completions. For now, raise
|
||||
# a clear error.
|
||||
raise NotImplementedError(
|
||||
"ProxyManagedServer.completion() is not supported. "
|
||||
"Use chat_completion() instead — the proxy handles template "
|
||||
"rendering internally."
|
||||
)
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get the current state synchronously from cache.
|
||||
|
||||
Call fetch_state() first to populate from the proxy, or use
|
||||
the nodes returned by this method after a chat_completion() call.
|
||||
|
||||
Returns:
|
||||
Dict with 'nodes': List[SequenceNode]
|
||||
"""
|
||||
if self._cached_nodes is not None:
|
||||
return {"nodes": self._cached_nodes}
|
||||
return {"nodes": []}
|
||||
|
||||
async def fetch_state(self) -> Dict[str, Any]:
|
||||
"""Fetch current state from the proxy (async).
|
||||
|
||||
Returns:
|
||||
Dict with 'nodes': List[SequenceNode]
|
||||
"""
|
||||
data = await self._get(f"/{self.session_uuid}/nodes")
|
||||
nodes = []
|
||||
for node_data in data.get("nodes", []):
|
||||
nodes.append(SequenceNode(**node_data))
|
||||
self._cached_nodes = nodes
|
||||
return {"nodes": nodes}
|
||||
|
||||
def reset(self):
|
||||
"""Clear cached state. The actual cleanup happens in __aexit__."""
|
||||
self._cached_nodes = None
|
||||
|
||||
async def cleanup(self):
|
||||
"""Delete the session on the proxy."""
|
||||
try:
|
||||
await self._delete(f"/{self.session_uuid}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup proxy session {self.session_uuid}: {e}")
|
||||
|
||||
# -- context manager support --
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Fetch final state before cleanup so callers can still access nodes
|
||||
try:
|
||||
await self.fetch_state()
|
||||
except Exception:
|
||||
pass
|
||||
await self.cleanup()
|
||||
|
||||
|
||||
async def create_proxy_session(
|
||||
proxy_url: str,
|
||||
base_url: Optional[str] = None,
|
||||
tool_parser: str = "hermes",
|
||||
track_tree: bool = False,
|
||||
model_name: str = "unknown",
|
||||
) -> ProxyManagedServer:
|
||||
"""Create a new session on the proxy and return a ProxyManagedServer.
|
||||
|
||||
Args:
|
||||
proxy_url: Base URL of the proxy (e.g. "http://localhost:9100").
|
||||
base_url: Pin to a specific backend server. In production, this
|
||||
comes from the atropos API's server allocation.
|
||||
tool_parser: vLLM tool parser name (default: "hermes").
|
||||
track_tree: Whether to use tree mode for tracking.
|
||||
model_name: Model name for response objects.
|
||||
|
||||
Returns:
|
||||
ProxyManagedServer instance ready to use.
|
||||
"""
|
||||
body = {
|
||||
"tool_parser": tool_parser,
|
||||
"track_tree": track_tree,
|
||||
}
|
||||
if base_url:
|
||||
body["base_url"] = base_url
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{proxy_url.rstrip('/')}/sessions/create",
|
||||
json=body,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200:
|
||||
error_msg = data.get("error", {}).get("message", str(data))
|
||||
raise RuntimeError(f"Failed to create proxy session: {error_msg}")
|
||||
|
||||
return ProxyManagedServer(
|
||||
proxy_url=proxy_url,
|
||||
session_uuid=data["uuid"],
|
||||
model_name=data.get("model_name", model_name),
|
||||
base_url=data.get("base_url"),
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue