atropos/atroposlib/envs/server_handling/server_manager.py

import asyncio
import inspect
import os
import warnings
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Optional, Type, Union
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.managed_server import (
DummyManagedServer,
ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
ReasoningConfig,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.vllm_server import VLLMServer
class ServerManagerConfig(BaseModel):
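    """
    Configuration options for ServerManager: slurm/testing flags, per-call
    completion splitting, and the optional ManagedServer proxy URL.
    """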
slurm: bool = Field(
default=False, description="Whether environment is running on slurm or not."
)
testing: bool = Field(
default=False, description="If set to True, environment uses mock OpenAI data."
)
max_n_completions: int = Field(
default=8,
description=(
"The maximum number of completions to request at once per server call. "
"Will split any n larger than this into multiple calls. "
"This is to help load balance servers."
),
)
proxy_url: Optional[str] = Field(
default=None,
description=(
"URL of the ManagedServer OpenAI proxy (e.g. 'http://localhost:9100'). "
"When set, managed_server(use_proxy=True) routes through this proxy. "
"Can also be set via ATROPOS_PROXY_URL environment variable."
),
)
class ServerManager:
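    """
    Routes OpenAI-style requests across a pool of inference servers.

    The manager picks the most available healthy backend per call, splits large
    `n` requests into smaller chunks (see `max_n_completions`), discovers
    inference nodes when running under slurm, and hands out ManagedServer
    sessions (in-process or via the proxy).

    Illustrative usage (URL and model name are placeholders):

        manager = ServerManager(
            configs=[
                APIServerConfig(
                    base_url="http://localhost:9004/v1",
                    model_name="my-model",
                    api_key="x",
                )
            ],
        )
        resp = await manager.chat_completion(
            messages=[{"role": "user", "content": "Hello"}], n=4
        )
    """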
def __init__(
self,
configs: Union[ServerBaseline, List[APIServerConfig]],
        server_class: Type[APIServer] = APIServer,
slurm=False,
testing=False,
max_n_completions=8,
reasoning_config: Optional[ReasoningConfig] = None,
proxy_url: Optional[str] = None,
use_proxy: bool = False,
tool_parser: Optional[str] = None,
):
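        """
        Args:
            configs: Either a ServerBaseline (expanded into per-node configs) or an
                explicit list of APIServerConfig entries.
            server_class: Concrete APIServer subclass to instantiate; when left as
                the abstract base class, it is chosen from the config's server_type.
            slurm: If True, copy the first config across inference nodes discovered
                from SLURM_JOB_NODELIST.
            testing: If True, replace all servers with a mock ServerHarness.
            max_n_completions: Per-call cap on n; larger requests are split.
            reasoning_config: Passed to each server; reasoning injection is handled there.
            proxy_url: URL of the ManagedServer proxy; falls back to ATROPOS_PROXY_URL.
            use_proxy: If True (or when a proxy URL is set), managed_server() routes
                through the proxy.
            tool_parser: Passed to ManagedServer for tool-call support.
        """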
self.max_n_completions = max_n_completions
self.reasoning_config = reasoning_config
# Proxy config — when use_proxy=True, managed_server() routes
# through the proxy HTTP API instead of creating in-process instances
self.proxy_url = proxy_url or os.environ.get("ATROPOS_PROXY_URL")
self.use_proxy = use_proxy or bool(self.proxy_url)
# Tool parser — passed to ManagedServer for tool call support
self.tool_parser = tool_parser
        # If the caller passed the abstract base class, pick a concrete server class
        # based on the config's server_type. A plain type()/isinstance check against
        # the base class is unreliable here because it is abstract (its type is
        # ABCMeta), so inspect.isabstract() is used instead.
        if inspect.isabstract(server_class):
            server_type = (
                configs[0].server_type
                if isinstance(configs, list)
                else configs.server_type
            )
            server_class = {
                "openai": OpenAIServer,
                "trl": TrlVllmServer,
                "sglang": SGLangServer,
                "vllm": VLLMServer,
            }.get(server_type)
            if server_class is None:
                raise ValueError(f"Invalid server type: {server_type}")
if testing:
# testing :)
self.servers = [ServerHarness()]
return
if not isinstance(configs, list):
urls = []
if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
nodelist = (
os.popen(
f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}'
)
.read()
.split("\n")
)
nodelist = [node for node in nodelist if node != ""]
if len(nodelist) < 2:
# localhost!
for i in range(4):
urls.append(f"http://localhost:{9000 + i + 4}/v1")
else:
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
for node in nodelist[num_training_nodes:]:
                        for i in range(8 // int(os.environ.get("INFER_TP", 1))):
urls.append(f"http://{node}:{9000 + i}/v1")
openai_configs = []
else:
# localhost!
for i in range(4):
urls.append(f"http://localhost:{9000 + i + 4}/v1")
openai_configs = []
for url in urls:
openai_configs.append(
APIServerConfig(
base_url=url,
timeout=configs.timeout,
num_max_requests_at_once=configs.num_max_requests_at_once,
num_requests_for_eval=configs.num_requests_for_eval,
model_name=configs.model_name,
rolling_buffer_length=configs.rolling_buffer_length,
api_key="x",
tokenizer_name=configs.tokenizer_name,
)
)
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in openai_configs
]
elif not slurm:
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
.read()
.split("\n")
)
nodelist = [node for node in nodelist if node != ""]
if len(nodelist) < 2:
                print(
                    "Not enough nodes to distribute across; assuming a single node"
                    " and that you've set up your SGLang servers appropriately."
                )
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
return
urls = []
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
for node in nodelist[num_training_nodes:]:
if node == "":
continue
                for i in range(8 // int(os.environ.get("INFER_TP", 1))):
urls.append(f"http://{node}:{9000 + i}/v1")
# assume at least one good config is passed in
new_configs = []
for i in range(len(urls)):
new_conf = configs[0].model_copy(deep=True)
new_conf.base_url = urls[i]
new_configs.append(new_conf)
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in new_configs
]
async def update_weight(self, weight: float):
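        """Forward a weight update to every backend server."""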
for server in self.servers:
await server.update_weight(weight)
def _get_server_base_url(self, server_idx: int = 0) -> Optional[str]:
"""Get the base_url from a server's config."""
if not self.servers:
return None
server = self.servers[server_idx]
if hasattr(server, "config") and hasattr(server.config, "base_url"):
return server.config.base_url
return None
async def wait_for_sem(self, is_training: bool):
"""
Wait for a server to be available. This is used to prevent the client from
overwhelming the server with requests.
"""
def get_available_slots():
if is_training:
eval_vals = [
(
max(0, server.eval_sem._value - server.eval_sem.min_val())
if server.eval_sem._value != server.eval_sem.max_val
else 0
)
for server in self.servers
]
return [
max(0, (server.sem._value - server.sem.min_val()) - eval_val)
for server, eval_val in zip(self.servers, eval_vals)
]
else:
return [
max(0, server.eval_sem._value - server.eval_sem.min_val())
for server in self.servers
]
sem_vals = get_available_slots()
while all(sem_val <= 0 for sem_val in sem_vals):
# None available... wait
await asyncio.sleep(1)
sem_vals = get_available_slots()
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Route chat completion to the most available server.
Reasoning config injection is handled by the individual servers.
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
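            # Note: each recursive call binds its own copy of kwargs when the
            # coroutine object is created, so reassigning kwargs["n"] below does
            # not affect calls that have already been queued.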
completions = []
total_n = n
while total_n > 0:
n_to_use = min(total_n, self.max_n_completions)
kwargs["n"] = n_to_use
completions.append(self.chat_completion(**kwargs))
total_n -= n_to_use
completions = await asyncio.gather(
*completions
) # type: List[ChatCompletion]
# merge choices into one
out = completions[0]
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
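        # The "split" kwarg ("train" or "eval") selects which request budget this
        # call counts against; it is consumed here and not forwarded to the backend.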
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
"""
Route completion to the most available server.
Reasoning config injection is handled by the individual servers.
Pass `skip_reasoning=True` to bypass reasoning injection for this call.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
completions = []
total_n = n
while total_n > 0:
n_to_use = min(total_n, self.max_n_completions)
kwargs["n"] = n_to_use
completions.append(self.completion(**kwargs))
total_n -= n_to_use
completions = await asyncio.gather(*completions) # type: List[Completion]
# merge choices into one
out = completions[0]
for completion in completions[1:]:
out.choices.extend(completion.choices)
return out
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].completion(**kwargs)
async def tokens_and_logprobs_completion(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Get tokens and logprobs from completion.
Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
Note: Reasoning config is NOT injected here - this method is for extracting
raw token-level data for training, not for generating reasoned responses.
"""
n = kwargs.get("n", 1)
if n > self.max_n_completions:
# Split into multiple completions
results = []
total_n = n
while total_n > 0:
n_to_use = min(total_n, self.max_n_completions)
kwargs["n"] = n_to_use
results.append(self.tokens_and_logprobs_completion(**kwargs))
total_n -= n_to_use
results = await asyncio.gather(*results)
            # Merge results: prompt_tokens should be the same for each call, so keep
            # the first set and extend the per-sample output lists.
prompt_tokens = results[0][0]
output_tokens = []
output_logprobs = []
finish_reasons = []
for _, out_tokens, out_logprobs, out_finish_reasons in results:
output_tokens.extend(out_tokens)
output_logprobs.extend(out_logprobs)
finish_reasons.extend(out_finish_reasons)
return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].tokens_and_logprobs_completion(
**kwargs
)
async def get_logprobs(self, **kwargs) -> dict:
"""
Route normalized prompt-logprob requests to the most available server.
Returns a normalized dict with:
- prompt_tokens
- prompt_topk_token_ids
- prompt_topk_logprobs
"""
is_train = kwargs.pop("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].get_logprobs(**kwargs)
@asynccontextmanager
    async def dedicated_server(self) -> AsyncGenerator[APIServer, None]:
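        """
        Acquire a slot on the most available healthy server and yield that server
        for the duration of the context; its semaphore is held until exit.
        """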
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = server.sem._value
        async with self.servers[most_available_server].sem:
            yield self.servers[most_available_server]
@asynccontextmanager
async def managed_server(
self,
tokenizer=None,
base_url: Optional[str] = None,
preserve_think_blocks: bool = False,
):
"""
Context manager that provides a ManagedServer instance.
The ManagedServer wraps the most available server and tracks text sequences
with aligned tokens and logprobs. State is automatically cleared on exit.
For OpenAI endpoints (which don't support token IDs/logprobs), a
DummyManagedServer is returned if the ATROPOS_ALLOW_DUMMY_MANAGED_SERVER
environment variable is set. Otherwise, a NotImplementedError is raised.
Args:
tokenizer: Optional tokenizer to use. If not provided, will attempt to
extract from server or create from model name.
base_url: Pin the session to a specific backend server by its base_url.
In production, this comes from the atropos API's server allocation.
preserve_think_blocks: If True, preserves <think> blocks in assistant messages,
which are sometimes stripped by chat templates. Defaults to False.
Usually not needed, since the chat template should be configured
to preserve thinking blocks until a user message arrives.
Yields:
ManagedServer, DummyManagedServer, or ProxyManagedServer instance
Raises:
NotImplementedError: If using OpenAI server without the
ATROPOS_ALLOW_DUMMY_MANAGED_SERVER env var set.
Example:
# In-process (default):
async with server_manager.managed_server() as managed:
response = await managed.chat_completion(
messages=[{"role": "user", "content": "Hello"}],
n=2, tools=[...], tool_choice="auto",
)
state = managed.get_state()
# Via proxy (configured at init with proxy_url= or ATROPOS_PROXY_URL):
# server_manager = ServerManager(configs, proxy_url="http://proxy:9100")
async with server_manager.managed_server() as managed:
response = await managed.chat_completion(...)
api_url = managed.get_url() # for external apps
"""
# -- Proxy path --
if self.use_proxy:
resolved_proxy_url = self.proxy_url
if not resolved_proxy_url:
raise ValueError(
"use_proxy=True requires proxy_url or ATROPOS_PROXY_URL env var "
"to be set at ServerManager init"
)
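            # Local import: the proxy client is only loaded when the proxy path
            # is actually used.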
from atroposlib.envs.server_handling.proxy_client import (
create_proxy_session,
)
model_name = (
self.servers[0].config.model_name
if self.servers and hasattr(self.servers[0], "config")
else "unknown"
)
proxy_managed = await create_proxy_session(
proxy_url=resolved_proxy_url,
base_url=base_url,
tool_parser=self.tool_parser or "hermes",
model_name=model_name,
)
try:
yield proxy_managed
finally:
await proxy_managed.cleanup()
return
# -- In-process path (existing logic) --
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = server.sem._value
selected_server = self.servers[most_available_server]
# Handle OpenAI servers separately - they don't support token IDs/logprobs
if isinstance(selected_server, OpenAIServer):
allow_dummy = os.environ.get(
"ATROPOS_ALLOW_DUMMY_MANAGED_SERVER", ""
).lower() in (
"1",
"true",
"yes",
)
if not allow_dummy:
raise NotImplementedError(
"OpenAI endpoints do not support token IDs or logprobs required for "
"ManagedServer. If you don't need actual token-level training data and "
"are okay with dummy placeholder values, set the environment variable:\n\n"
" export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1\n\n"
"WARNING: The DummyManagedServer will return placeholder token IDs and "
"logprobs (all zeros) that are NOT suitable for training. Use only for "
"evaluation or testing workflows."
)
warnings.warn(
"Using DummyManagedServer with OpenAI endpoint. Token IDs and logprobs "
"will be placeholder values and are NOT suitable for training."
)
managed = DummyManagedServer(server=selected_server, tokenizer=tokenizer)
try:
yield managed
finally:
managed.reset()
else:
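            # Non-OpenAI backends expose token IDs and logprobs, so a full
            # ManagedServer can track sequences with aligned tokens and logprobs.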
managed = ManagedServer(
server=selected_server,
tokenizer=tokenizer,
tool_parser=self.tool_parser,
preserve_think_blocks=preserve_think_blocks,
)
try:
yield managed
finally:
# Clean up: reset tracked sequences
managed.reset()