mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
|
|
@ -41,6 +41,14 @@ class ServerManagerConfig(BaseModel):
|
|||
"This is to help load balance servers."
|
||||
),
|
||||
)
|
||||
proxy_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"URL of the ManagedServer OpenAI proxy (e.g. 'http://localhost:9100'). "
|
||||
"When set, managed_server(use_proxy=True) routes through this proxy. "
|
||||
"Can also be set via ATROPOS_PROXY_URL environment variable."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ServerManager:
|
||||
|
|
@ -52,9 +60,18 @@ class ServerManager:
|
|||
testing=False,
|
||||
max_n_completions=8,
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
proxy_url: Optional[str] = None,
|
||||
use_proxy: bool = False,
|
||||
tool_parser: Optional[str] = None,
|
||||
):
|
||||
self.max_n_completions = max_n_completions
|
||||
self.reasoning_config = reasoning_config
|
||||
# Proxy config — when use_proxy=True, managed_server() routes
|
||||
# through the proxy HTTP API instead of creating in-process instances
|
||||
self.proxy_url = proxy_url or os.environ.get("ATROPOS_PROXY_URL")
|
||||
self.use_proxy = use_proxy or bool(self.proxy_url)
|
||||
# Tool parser — passed to ManagedServer for tool call support
|
||||
self.tool_parser = tool_parser
|
||||
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
|
||||
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
|
||||
# an ABCMeta, not what you're expecting.
|
||||
|
|
@ -364,8 +381,10 @@ class ServerManager:
|
|||
|
||||
@asynccontextmanager
|
||||
async def managed_server(
|
||||
self, tokenizer=None
|
||||
) -> AsyncGenerator[Union[ManagedServer, DummyManagedServer], None]:
|
||||
self,
|
||||
tokenizer=None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Context manager that provides a ManagedServer instance.
|
||||
|
||||
|
|
@ -379,25 +398,63 @@ class ServerManager:
|
|||
Args:
|
||||
tokenizer: Optional tokenizer to use. If not provided, will attempt to
|
||||
extract from server or create from model name.
|
||||
base_url: Pin the session to a specific backend server by its base_url.
|
||||
In production, this comes from the atropos API's server allocation.
|
||||
|
||||
Yields:
|
||||
ManagedServer (or DummyManagedServer for OpenAI) instance wrapping
|
||||
the selected server
|
||||
ManagedServer, DummyManagedServer, or ProxyManagedServer instance
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If using OpenAI server without the
|
||||
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
|
||||
|
||||
Example:
|
||||
# In-process (default):
|
||||
async with server_manager.managed_server() as managed:
|
||||
response = await managed.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
n=2
|
||||
n=2, tools=[...], tool_choice="auto",
|
||||
)
|
||||
state = managed.get_state()
|
||||
# Process state...
|
||||
# State is automatically cleared when exiting context
|
||||
|
||||
# Via proxy (configured at init with proxy_url= or ATROPOS_PROXY_URL):
|
||||
# server_manager = ServerManager(configs, proxy_url="http://proxy:9100")
|
||||
async with server_manager.managed_server() as managed:
|
||||
response = await managed.chat_completion(...)
|
||||
api_url = managed.get_url() # for external apps
|
||||
"""
|
||||
# -- Proxy path --
|
||||
if self.use_proxy:
|
||||
resolved_proxy_url = self.proxy_url
|
||||
if not resolved_proxy_url:
|
||||
raise ValueError(
|
||||
"use_proxy=True requires proxy_url or ATROPOS_PROXY_URL env var "
|
||||
"to be set at ServerManager init"
|
||||
)
|
||||
|
||||
from atroposlib.envs.server_handling.proxy_client import (
|
||||
create_proxy_session,
|
||||
)
|
||||
|
||||
model_name = (
|
||||
self.servers[0].config.model_name
|
||||
if self.servers and hasattr(self.servers[0], "config")
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
proxy_managed = await create_proxy_session(
|
||||
proxy_url=resolved_proxy_url,
|
||||
base_url=base_url,
|
||||
tool_parser=self.tool_parser or "hermes",
|
||||
model_name=model_name,
|
||||
)
|
||||
try:
|
||||
yield proxy_managed
|
||||
finally:
|
||||
await proxy_managed.cleanup()
|
||||
return
|
||||
|
||||
# -- In-process path (existing logic) --
|
||||
most_available_server = 0
|
||||
most_available_server_num_slots = -1
|
||||
for i, server in enumerate(self.servers):
|
||||
|
|
@ -441,7 +498,11 @@ class ServerManager:
|
|||
finally:
|
||||
managed.reset()
|
||||
else:
|
||||
managed = ManagedServer(server=selected_server, tokenizer=tokenizer)
|
||||
managed = ManagedServer(
|
||||
server=selected_server,
|
||||
tokenizer=tokenizer,
|
||||
tool_parser=self.tool_parser,
|
||||
)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue