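"""Server management for routing requests across inference servers.

ServerManager selects the concrete server class from the config's server_type
(OpenAI, TRL vLLM, SGLang, or vLLM), discovers endpoints locally or via Slurm,
splits requests with large `n` into chunks of `max_n_completions`, and routes
each call to the least-loaded healthy server.
"""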
import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Optional, Union

from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field

from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
    APIServer,
    APIServerConfig,
    ReasoningConfig,
    ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.vllm_server import VLLMServer


class ServerManagerConfig(BaseModel):
    slurm: bool = Field(
        default=False, description="Whether environment is running on slurm or not."
    )
    testing: bool = Field(
        default=False, description="If set to True, environment uses mock OpenAI data."
    )
    max_n_completions: int = Field(
        default=8,
        description=(
            "The maximum number of completions to request at once per server call. "
            "Will split any n larger than this into multiple calls. "
            "This is to help load balance servers."
        ),
    )


class ServerManager:
    def __init__(
        self,
        configs: Union[ServerBaseline, List[APIServerConfig]],
        server_class: APIServer = APIServer,
        slurm=False,
        testing=False,
        max_n_completions=8,
        reasoning_config: Optional[ReasoningConfig] = None,
    ):
        self.max_n_completions = max_n_completions
        self.reasoning_config = reasoning_config
        # First we check to see if it's the base server class, and if so, we need to select the appropriate server class.
        # You can't use type() to check if it's the base server class; because it's an abstract class, it'll appear as
        # an ABCMeta, not what you're expecting.
        if inspect.isabstract(server_class):
            if not isinstance(configs, list):
                if configs.server_type == "openai":
                    server_class = OpenAIServer
                elif configs.server_type == "trl":
                    server_class = TrlVllmServer
                elif configs.server_type == "sglang":
                    server_class = SGLangServer
                elif configs.server_type == "vllm":
                    server_class = VLLMServer
                else:
                    raise ValueError(f"Invalid server type: {configs.server_type}")
            else:
                if configs[0].server_type == "openai":
                    server_class = OpenAIServer
                elif configs[0].server_type == "trl":
                    server_class = TrlVllmServer
                elif configs[0].server_type == "sglang":
                    server_class = SGLangServer
                elif configs[0].server_type == "vllm":
                    server_class = VLLMServer
                else:
                    raise ValueError(f"Invalid server type: {configs[0].server_type}")
        if testing:
            # testing :)
            self.servers = [ServerHarness()]
            return
        if not isinstance(configs, list):
            urls = []
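            # Discover inference endpoints. Under Slurm with multiple nodes, the
            # first NUM_TRAINING_NODES hosts in the job's nodelist train and the
            # remaining hosts each serve 8 // INFER_TP engines on ports 9000,
            # 9001, ...; otherwise four local engines on ports 9004-9007 are
            # assumed.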
            if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
                nodelist = (
                    os.popen(
                        f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}'
                    )
                    .read()
                    .split("\n")
                )
                nodelist = [node for node in nodelist if node != ""]
                if len(nodelist) < 2:
                    # localhost!
                    for i in range(4):
                        urls.append(f"http://localhost:{9000 + i + 4}/v1")
                else:
                    num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
                    for node in nodelist[num_training_nodes:]:
                        # INFER_TP comes from the environment as a string, so cast it.
                        for i in range(8 // int(os.environ.get("INFER_TP", 1))):
                            urls.append(f"http://{node}:{9000 + i}/v1")
                openai_configs = []
            else:
                # localhost!
                for i in range(4):
                    urls.append(f"http://localhost:{9000 + i + 4}/v1")
                openai_configs = []
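            # Fan the single ServerBaseline out into one APIServerConfig per
            # discovered endpoint, all sharing the same model and request limits.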
            for url in urls:
                openai_configs.append(
                    APIServerConfig(
                        base_url=url,
                        timeout=configs.timeout,
                        num_max_requests_at_once=configs.num_max_requests_at_once,
                        num_requests_for_eval=configs.num_requests_for_eval,
                        model_name=configs.model_name,
                        rolling_buffer_length=configs.rolling_buffer_length,
                        api_key="x",
                    )
                )
            self.servers = [
                server_class(config, reasoning_config=reasoning_config)
                for config in openai_configs
            ]
        elif not slurm:
            self.servers = [
                server_class(config, reasoning_config=reasoning_config)
                for config in configs
            ]
        else:
            nodelist = (
                os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
                .read()
                .split("\n")
            )
            nodelist = [node for node in nodelist if node != ""]
            if len(nodelist) < 2:
                print(
                    "Not enough nodes to distribute to, assuming single node"
                    " and you've set up your sglang appropriately."
                )
                self.servers = [
                    server_class(config, reasoning_config=reasoning_config)
                    for config in configs
                ]
                return
            urls = []
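            # Everything after the first NUM_TRAINING_NODES hosts in the Slurm
            # nodelist is treated as an inference node.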
            num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
            for node in nodelist[num_training_nodes:]:
                if node == "":
                    continue
                # INFER_TP comes from the environment as a string, so cast it.
                for i in range(8 // int(os.environ.get("INFER_TP", 1))):
                    urls.append(f"http://{node}:{9000 + i}/v1")
            # assume at least one good config is passed in
            new_configs = []
            for i in range(len(urls)):
                new_conf = configs[0].model_copy(deep=True)
                new_conf.base_url = urls[i]
                new_configs.append(new_conf)
            self.servers = [
                server_class(config, reasoning_config=reasoning_config)
                for config in new_configs
            ]

    async def update_weight(self, weight: float):
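        """Forward `weight` to every managed server."""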
        for server in self.servers:
            await server.update_weight(weight)

    async def wait_for_sem(self, is_training: bool):
        """
        Wait for a server to be available. This is used to prevent the client from
        overwhelming the server with requests.
        """

        def get_available_slots():
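            # Free capacity per server, clamped at zero. Training calls use the
            # main semaphore's headroom minus the eval semaphore's headroom once
            # any eval traffic is in flight; eval calls use the eval semaphore only.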
            if is_training:
                eval_vals = [
                    (
                        max(0, server.eval_sem._value - server.eval_sem.min_val())
                        if server.eval_sem._value != server.eval_sem.max_val
                        else 0
                    )
                    for server in self.servers
                ]
                return [
                    max(0, (server.sem._value - server.sem.min_val()) - eval_val)
                    for server, eval_val in zip(self.servers, eval_vals)
                ]
            else:
                return [
                    max(0, server.eval_sem._value - server.eval_sem.min_val())
                    for server in self.servers
                ]

        sem_vals = get_available_slots()
        while all(sem_val <= 0 for sem_val in sem_vals):
            # None available... wait
            await asyncio.sleep(1)
            sem_vals = get_available_slots()

    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """
        Route chat completion to the most available server.

        Reasoning config injection is handled by the individual servers.
        Pass `skip_reasoning=True` to bypass reasoning injection for this call.
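
        Example:
            response = await server_manager.chat_completion(
                messages=[{"role": "user", "content": "Hello"}], n=16
            )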
        """
        n = kwargs.get("n", 1)
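        # Requests with n > max_n_completions are fanned out into several
        # concurrent sub-requests and the returned choices are stitched back
        # into a single response.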
        if n > self.max_n_completions:
            # Split into multiple completions
            completions = []
            total_n = n
            while total_n > 0:
                n_to_use = min(total_n, self.max_n_completions)
                kwargs["n"] = n_to_use
                completions.append(self.chat_completion(**kwargs))
                total_n -= n_to_use
            completions = await asyncio.gather(
                *completions
            )  # type: List[ChatCompletion]
            # merge choices into one
            out = completions[0]
            for completion in completions[1:]:
                out.choices.extend(completion.choices)
            return out
        is_train = kwargs.pop("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
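        # Greedy pick: route to the healthy server with the most free slots on
        # the relevant (train or eval) semaphore.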
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )

        return await self.servers[most_available_server].chat_completion(**kwargs)

    async def completion(self, **kwargs) -> Completion:
        """
        Route completion to the most available server.

        Reasoning config injection is handled by the individual servers.
        Pass `skip_reasoning=True` to bypass reasoning injection for this call.
        """
        n = kwargs.get("n", 1)
        if n > self.max_n_completions:
            # Split into multiple completions
            completions = []
            total_n = n
            while total_n > 0:
                n_to_use = min(total_n, self.max_n_completions)
                kwargs["n"] = n_to_use
                completions.append(self.completion(**kwargs))
                total_n -= n_to_use
            completions = await asyncio.gather(*completions)  # type: List[Completion]
            # merge choices into one
            out = completions[0]
            for completion in completions[1:]:
                out.choices.extend(completion.choices)
            return out
        is_train = kwargs.pop("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )

        return await self.servers[most_available_server].completion(**kwargs)

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Get tokens and logprobs from completion.
        Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).

        Note: Reasoning config is NOT injected here - this method is for extracting
        raw token-level data for training, not for generating reasoned responses.
        """
        n = kwargs.get("n", 1)
        if n > self.max_n_completions:
            # Split into multiple completions
            results = []
            total_n = n
            while total_n > 0:
                n_to_use = min(total_n, self.max_n_completions)
                kwargs["n"] = n_to_use
                results.append(self.tokens_and_logprobs_completion(**kwargs))
                total_n -= n_to_use
            results = await asyncio.gather(*results)
            # Merge results - prompt_tokens should be same, extend output lists
            prompt_tokens = results[0][0]
            output_tokens = []
            output_logprobs = []
            finish_reasons = []
            for _, out_tokens, out_logprobs, out_finish_reasons in results:
                output_tokens.extend(out_tokens)
                output_logprobs.extend(out_logprobs)
                finish_reasons.extend(out_finish_reasons)
            return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)

        is_train = kwargs.pop("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )

        return await self.servers[most_available_server].tokens_and_logprobs_completion(
            **kwargs
        )

    @asynccontextmanager
    async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
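        """
        Reserve a request slot on the least-loaded healthy server and yield that
        server. The slot (its semaphore) is held for the duration of the context.
        """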
        most_available_server = 0
        most_available_server_num_slots = -1
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = server.sem._value
        async with self.servers[most_available_server].sem:
            try:
                yield self.servers[most_available_server]
            finally:
                pass

    @asynccontextmanager
    async def managed_server(
        self, tokenizer=None
    ) -> AsyncGenerator[ManagedServer, None]:
        """
        Context manager that provides a ManagedServer instance.

        The ManagedServer wraps the most available server and tracks text sequences
        with aligned tokens and logprobs. State is automatically cleared on exit.

        Args:
            tokenizer: Optional tokenizer to use. If not provided, will attempt to
                extract from server or create from model name.

        Yields:
            ManagedServer instance wrapping the selected server

        Example:
            async with server_manager.managed_server() as managed:
                response = await managed.chat_completion(
                    messages=[{"role": "user", "content": "Hello"}],
                    n=2
                )
                state = managed.get_state()
                # Process state...
            # State is automatically cleared when exiting context
        """
        most_available_server = 0
        most_available_server_num_slots = -1
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = server.sem._value

        # Create ManagedServer wrapping the selected server
        managed = ManagedServer(
            server=self.servers[most_available_server], tokenizer=tokenizer
        )

        try:
            yield managed
        finally:
            # Clean up: reset tracked sequences
            managed.reset()