Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-14 17:27:44 -07:00
commit 00dd120067
34 changed files with 1620 additions and 386 deletions


@@ -58,10 +58,38 @@ These methods have default implementations or are optional based on your needs:
* **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command.
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[APIServerConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `APIServerConfig` instances) when run via the CLI `serve` command; a sketch of such an override follows this list.
* **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.
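For illustration, here is a minimal, hedged sketch of a `config_init` override, as mentioned above. The import paths follow the diff below; the two local endpoints and their ports are assumptions:
```python
from typing import List, Tuple, Union

from atroposlib.envs import BaseEnv, BaseEnvConfig
from atroposlib.envs.server_handling.server_baseline import APIServerConfig, ServerBaseline


class MyEnv(BaseEnv):
    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
        # hypothetical: two local endpoints instead of the default ServerBaseline()
        server_configs = [
            APIServerConfig(model_name="default", base_url="http://localhost:9004/v1", api_key="x"),
            APIServerConfig(model_name="default", base_url="http://localhost:9005/v1", api_key="x"),
        ]
        return cls.env_config_cls(), server_configs
```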
## Overrideable Class Variables
These class-level variables in `BaseEnv` can be overridden in your subclass to customize its behavior:
* **`name: Optional[str]`**:
* Default: `None`
* Purpose: You can set a string name for your environment. This name is used by default for `wandb_name` in the `BaseEnvConfig` if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.
* **`env_config_cls: Type[BaseEnvConfig]`**:
* Default: `BaseEnvConfig`
* Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what `BaseEnvConfig` offers, you should create a new class that inherits from `BaseEnvConfig` (or a subclass thereof) and assign it to `env_config_cls`. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.
```python
from pydantic import Field

from atroposlib.envs import BaseEnv, BaseEnvConfig


class MyEnvConfig(BaseEnvConfig):
    my_custom_param: str = Field(default="default_value", description="A custom parameter for MyEnv")


class MyEnv(BaseEnv):
    env_config_cls = MyEnvConfig
    name = "MyCustomEnvironment"
    # ... other implementations
```
* **`server_cls: Type[APIServer]`**:
* Default: `APIServer`
* Purpose: Specifies the class used to manage interactions with API servers (e.g., inference endpoints). It is mostly intended for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API, you can use it to slot in any modifications you need; see the sketch after this list.
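As referenced above, a hedged sketch of slotting in a customized server class; the temperature tweak is purely hypothetical:
```python
from atroposlib.envs import BaseEnv
from atroposlib.envs.server_handling.openai_server import OpenAIServer


class MyPatchedServer(OpenAIServer):
    """Reuses OpenAIServer but nudges each request (hypothetical modification)."""

    async def _chat_completion_wrapper(self, **kwargs):
        kwargs.setdefault("temperature", 0.7)  # assumed nonstandard default
        return await super()._chat_completion_wrapper(**kwargs)


class MyEnv(BaseEnv):
    server_cls = MyPatchedServer
```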
## Provided Functionality
`BaseEnv` provides several helpful features:


@@ -40,7 +40,8 @@ from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
from .server_handling.server_manager import (
OpenaiConfig,
APIServer,
APIServerConfig,
ServerBaseline,
ServerManager,
ServerManagerConfig,
@@ -163,13 +164,14 @@ class BaseEnvConfig(BaseModel):
class BaseEnv(ABC):
name = None
env_config_cls = BaseEnvConfig
name: Optional[str] = None
env_config_cls: BaseEnvConfig = BaseEnvConfig
server_cls: APIServer = APIServer
def __init__(
self,
config: BaseEnvConfig,
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm=False,
testing=False,
):
@@ -184,7 +186,9 @@ class BaseEnv(ABC):
self.last_loop_time = None
self.last_completed_item = None
self.config = config
self.server = ServerManager(server_configs, slurm=slurm, testing=testing)
self.server = ServerManager(
server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
)
self.workers = set()
self.eval_workers = set()
self.backlog = []
@@ -234,7 +238,7 @@
@classmethod
def config_init(
cls,
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]:
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
"""
Initialize the config
"""
@@ -1020,7 +1024,6 @@
Returns:
type: The CliServeConfig class for serving commands.
"""
# Get the default configurations defined by the specific environment class
default_env_config, default_server_configs = cls.config_init()
@@ -1032,8 +1035,8 @@
class CliServeConfig(
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
get_prefixed_pydantic_model(
OpenaiConfig, openai_full_prefix
), # Use OpenaiConfig for CLI args
APIServerConfig, openai_full_prefix
), # Use APIServerConfig for CLI args
ServerManagerConfig, # ServerManager args are not namespaced by default
Cmd,
):
@@ -1089,7 +1092,7 @@
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
)
if (
isinstance(default_server_configs, list)
@@ -1101,11 +1104,11 @@
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
yaml_oai_config,
oai_cli_passed_args,
)
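To make the merge order concrete, a hedged illustration, assuming `merge_dicts` merges left to right with later dictionaries winning (the values are hypothetical):
```python
merged = merge_dicts(
    {"model_name": "default", "timeout": 1200},  # class defaults
    {"model_name": "my-model"},                  # YAML overrides
    {"timeout": 600},                            # CLI overrides
)
assert merged == {"model_name": "my-model", "timeout": 600}
```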
@@ -1189,7 +1192,7 @@
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
use_wandb=True,
)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=None,
@@ -1200,10 +1203,7 @@
)
# Get the base default configurations from the specific environment class
(
default_env_config,
default_server_configs,
) = cls.config_init()
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@@ -1215,8 +1215,7 @@
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
)
openai_config_cls_new_defaults = adjust_model_defaults(
OpenaiConfig,
PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
)
server_manager_config_cls_new_defaults = adjust_model_defaults(
ServerManagerConfig,
@@ -1283,7 +1282,7 @@
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
)
if (
@@ -1296,11 +1295,11 @@
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_oai_config,
oai_cli_passed_args,


@@ -1,142 +1,28 @@
import asyncio
import collections
import time
from asyncio import exceptions
from typing import Optional
import warnings
import aiohttp
import numpy as np
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from pydantic_cli import FailedExecutionException
from tenacity import retry, stop_after_attempt, wait_random_exponential
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
class OpenaiConfig(BaseModel):
class OpenAIServer(APIServer):
"""
Configuration for the server manager.
OpenAI server handling.
"""
api_key: Optional[str] = Field(
default=None, description="API key for OpenAI API. Use 'x' for local servers."
)
base_url: Optional[str] = Field(
default=None,
description="URL of the API endpoint. None if using official OpenAI API, otherwise local server URL.",
)
timeout: int = Field(
default=1200, description="Timeout for the request in seconds."
)
num_max_requests_at_once: int = Field(
default=512,
description="Maximum number of concurrent requests. Note: You should divide this by the n kwarg.",
)
num_requests_for_eval: int = Field(
default=64, description="Maximum number of concurrent requests for evaluation."
)
model_name: str = Field(
default="default",
description="The model name to use. Required for both OpenAI and local models.",
)
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
def __init__(self, value: int):
super().__init__(value=value)
self.max_val = value
self.weight = 1.0
def update_weight(self, weight: float) -> None:
self.weight = weight
def min_val(self):
return self.max_val * (1.0 - self.weight)
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
If weight is set, it'll only wake up next if the value is greater than the max_val * weight
"""
self._value += 1
if self._value > self.min_val():
self._wake_up_next()
def locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
return self._value <= self.min_val() or (
any(not w.cancelled() for w in (self._waiters or ()))
)
async def acquire(self):
"""Acquire a semaphore.
If the internal counter is larger than zero on entry,
decrement it by one and return True immediately. If it is
zero on entry, block, waiting until some other coroutine has
called release() to make it larger than 0, and then return
True.
"""
if not self.locked():
self._value -= 1
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
if not fut.cancelled():
self._value += 1
self._wake_up_next()
raise
if self._value > self.min_val():
self._wake_up_next()
return True
class OpenAIServer:
def __init__(self, config: OpenaiConfig):
self.config = config
def __init__(self, config: APIServerConfig):
self.openai = openai.AsyncClient(
api_key=config.api_key,
base_url=config.base_url,
timeout=config.timeout,
)
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
self.server_healthy = True
self.attempts_list = []
self.request_timings = []
# in case eval is much different, we should keep different buffers
self.eval_attempts_list = []
self.eval_request_timings = []
self.check_task = None
self.initialized = False
async def update_weight(self, weight: float) -> None:
# need to update sems
self.sem.update_weight(weight)
self.eval_sem.update_weight(weight)
super().__init__(config)
async def check_server_status_task(self):
while True:
@@ -156,147 +42,90 @@ class OpenAIServer:
self.server_healthy = False
await asyncio.sleep(1)
async def wandb_metrics(
self, metrics_dict: Optional[dict], server_name: Optional[str]
):
if server_name is None:
server_name = "server"
if len(self.request_timings) > 0:
metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
self.request_timings
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
Wrapper for the chat completion using the openai client.
"""
assert (
kwargs.get("model", None) is not None
), "Model is required for chat completion!"
assert (
kwargs.get("messages", None) is not None
), "Messages are required for chat completion!"
if self.config.n_kwarg_is_ignored:
n = kwargs.pop("n", 1)
completion_list = await asyncio.gather(
*[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
)
metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
self.request_timings
)
metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
self.request_timings, 99
)
if len(self.eval_request_timings) > 0:
metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
self.eval_request_timings, 99
)
if len(self.attempts_list) > 0:
metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
self.attempts_list
)
if len(self.eval_attempts_list) > 0:
metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
self.eval_attempts_list
)
return metrics_dict
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def chat_completion(self, **kwargs) -> ChatCompletion:
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
completions = completion_list[0]
if n > 1:
for c in completion_list[1:]:
completions.choices.extend(c.choices)
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._chat_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
completions = await self.openai.chat.completions.create(**kwargs)
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._chat_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp(self, stat_dict, **kwargs) -> Completion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
async def completion(self, **kwargs) -> Completion:
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
if "n" in kwargs:
n = kwargs["n"]
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
n = 1
completions = await self.openai.chat.completions.create(**kwargs)
if len(completions.choices) != n:
if len(completions.choices) != 1:
raise ValueError(
f"Expected 1 or {n} completions, got {len(completions.choices)}!"
)
else:
warnings.warn("n kwarg is ignored by the API, setting n_kwarg_is_ignored to True")
self.config.n_kwarg_is_ignored = True
completion_list = await asyncio.gather(
*[
self.openai.chat.completions.create(**kwargs)
for _ in range(1, n)
]
)
for c in completion_list:
completions.choices.extend(c.choices)
return completions
async def _completion_wrapper(self, **kwargs) -> Completion:
"""
Wrapper for the completion using the openai client.
"""
assert (
kwargs.get("model", None) is not None
), "Model is required for completion!"
assert (
kwargs.get("prompt", None) is not None
), "Prompt is required for completion!"
if self.config.n_kwarg_is_ignored:
n = kwargs.pop("n", 1)
completion_list = await asyncio.gather(
*[self.openai.completions.create(**kwargs) for _ in range(n)]
)
completions = completion_list[0]
if n > 1:
for c in completion_list[1:]:
completions.choices.extend(c.choices)
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
if "n" in kwargs:
n = kwargs["n"]
else:
n = 1
completions = await self.openai.completions.create(**kwargs)
if len(completions.choices) != n:
if len(completions.choices) != 1:
raise ValueError(
f"Expected 1 or {n} completions, got {len(completions.choices)}!"
)
else:
warnings.warn("n kwarg is ignored by the API, setting n_kwarg_is_ignored to True")
self.config.n_kwarg_is_ignored = True
completion_list = await asyncio.gather(
*[self.openai.completions.create(**kwargs) for _ in range(1, n)]
)
for c in completion_list:
completions.choices.extend(c.choices)
return completions
def resolve_openai_configs(
@@ -338,7 +167,7 @@ def resolve_openai_configs(
f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
)
try:
server_configs = [OpenaiConfig(**cfg) for cfg in openai_yaml_config]
server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config]
except Exception as e:
raise FailedExecutionException(
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
@@ -354,14 +183,14 @@
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
)
try:
final_openai_config = OpenaiConfig(**openai_config_dict)
final_openai_config = APIServerConfig(**openai_config_dict)
except Exception as e:
raise FailedExecutionException(
f"Error creating final OpenAI configuration from merged settings: {e}\n"
f"Merged Dict: {openai_config_dict}"
) from e
if isinstance(default_server_configs, OpenaiConfig):
if isinstance(default_server_configs, APIServerConfig):
server_configs = final_openai_config
elif isinstance(default_server_configs, list):
server_configs = [final_openai_config]
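For reference, a hedged sketch of the shapes involved in the multi-server branch; the entries and ports are assumptions:
```python
# each mapping parsed from YAML under the openai namespace becomes one APIServerConfig
openai_yaml_config = [
    {"model_name": "default", "base_url": "http://localhost:9004/v1", "api_key": "x"},
    {"model_name": "default", "base_url": "http://localhost:9005/v1", "api_key": "x"},
]
server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config]
```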


@@ -0,0 +1,340 @@
import asyncio
import collections
import time
from abc import ABC, abstractmethod
from asyncio import exceptions
from typing import Literal, Optional
import numpy as np
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
def __init__(self, value: int):
super().__init__(value=value)
self.max_val = value
self.weight = 1.0
def update_weight(self, weight: float) -> None:
"""
Update the weight of the semaphore.
"""
self.weight = weight
def min_val(self):
"""
Returns the minimum value of the semaphore.
"""
return self.max_val * (1.0 - self.weight)
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
If weight is set, it'll only wake up next if the value is greater than the max_val * weight
"""
self._value += 1
if self._value > self.min_val():
self._wake_up_next()
def locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
return self._value <= self.min_val() or (
any(not w.cancelled() for w in (self._waiters or ()))
)
async def acquire(self):
"""Acquire a semaphore.
If the internal counter is larger than zero on entry,
decrement it by one and return True immediately. If it is
zero on entry, block, waiting until some other coroutine has
called release() to make it larger than 0, and then return
True.
"""
if not self.locked():
self._value -= 1
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
if not fut.cancelled():
self._value += 1
self._wake_up_next()
raise
if self._value > self.min_val():
self._wake_up_next()
return True
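A hedged demo of the adaptive weighting with illustrative values: at `weight=0.5`, `min_val()` is 2, so only the top half of the 4 permits can be acquired before `locked()` reports True:
```python
import asyncio

# uses AsyncSemWithAdaptiveWeight as defined above


async def demo():
    sem = AsyncSemWithAdaptiveWeight(4)
    sem.update_weight(0.5)  # min_val() == 2
    await sem.acquire()     # internal counter: 4 -> 3
    await sem.acquire()     # internal counter: 3 -> 2
    print(sem.locked())     # True: counter <= min_val(), further acquires block
    sem.release()


asyncio.run(demo())
```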
class ServerBaseline(BaseModel):
"""
Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
assuming a 1:1 split of GPUs.
"""
timeout: int = Field(
default=1200, description="Timeout for the request in seconds."
)
num_max_requests_at_once: int = Field(
default=512,
description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
)
num_requests_for_eval: int = Field(
default=64, description="Maximum number of concurrent requests for evaluation."
)
model_name: str = Field(
default="default",
description="The model name to use. Only works with sglang, please provide the model name.",
)
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
server_type: Literal["openai", "trl"] = Field(
default="openai", description="Type of server to use, openai or trl"
)
class APIServerConfig(ServerBaseline):
"""
API server configuration.
"""
api_key: Optional[str] = Field(default="", description="API key for the server.")
base_url: Optional[str] = Field(default="", description="Base URL for the server.")
n_kwarg_is_ignored: bool = Field(
default=False, description="Whether the n kwarg is ignored by this API server."
)
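A hedged example of constructing a config for a local endpoint; the model name, URL, and concurrency cap are assumptions:
```python
local_config = APIServerConfig(
    model_name="my-local-model",          # hypothetical
    base_url="http://localhost:9004/v1",  # assumed local server URL
    api_key="x",                          # placeholder key for local servers
    num_max_requests_at_once=32,
)
```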
class APIServer(ABC):
"""
Abstract class for API servers.
"""
def __init__(self, config: APIServerConfig):
self.config = config
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
self.server_healthy = True
self.attempts_list = []
self.request_timings = []
# in case eval is much different, we should keep different buffers
self.eval_attempts_list = []
self.eval_request_timings = []
self.check_task = None
self.initialized = False
async def update_weight(self, weight: float) -> None:
"""
Update the weight of the semaphores
"""
# need to update sems
self.sem.update_weight(weight)
self.eval_sem.update_weight(weight)
@abstractmethod
async def check_server_status_task(self):
"""
Check the status of the server. Should be overridden by the child class.
Set self.server_healthy to True if the server is healthy.
"""
self.server_healthy = False
async def wandb_metrics(
self, metrics_dict: Optional[dict], server_name: Optional[str]
):
"""
Add metrics to the metrics dictionary.
If you want to add more metrics, you can do so by overriding this method, but make sure to call
super().wandb_metrics(metrics_dict, server_name) first to get the default metrics, if you still want them.
"""
if server_name is None:
server_name = "server"
if len(self.request_timings) > 0:
metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
self.request_timings
)
metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
self.request_timings
)
metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
self.request_timings, 99
)
if len(self.eval_request_timings) > 0:
metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
self.eval_request_timings, 99
)
if len(self.attempts_list) > 0:
metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
self.attempts_list
)
if len(self.eval_attempts_list) > 0:
metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
self.eval_attempts_list
)
return metrics_dict
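Per the docstring above, a hedged sketch of extending the default metrics in a concrete subclass; the extra metric is hypothetical:
```python
from atroposlib.envs.server_handling.openai_server import OpenAIServer


class MyServer(OpenAIServer):
    async def wandb_metrics(self, metrics_dict, server_name=None):
        # keep the default metrics, then add our own
        metrics_dict = await super().wandb_metrics(metrics_dict, server_name)
        metrics_dict[f"server/{server_name or 'server'}_logged_requests"] = len(
            self.request_timings
        )
        return metrics_dict
```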
@abstractmethod
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
Wrapper for the chat completion. Should be overridden by the child class and return a ChatCompletion object.
"""
pass
@abstractmethod
async def _completion_wrapper(self, **kwargs) -> Completion:
"""
Wrapper for the completion. Should be overridden by the child class and return a Completion object.
"""
pass
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
"""
Simple retry and stat collection wrapper for the chat completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._chat_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
"""
Simple retry and stat collection wrapper for the chat completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._chat_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
"""
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._chat_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._chat_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
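A hedged usage example (inside an async context; the message is illustrative). Anything other than `split="train"` routes through the eval semaphore and the eval metric buffers:
```python
resp = await server.chat_completion(
    messages=[{"role": "user", "content": "Say hi."}],
    n=1,
    split="eval",  # popped before the request reaches the backend
)
print(resp.choices[0].message.content)
```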
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp(self, stat_dict, **kwargs) -> Completion:
"""
Simple retry and stat collection wrapper for the completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
"""
Simple retry and stat collection wrapper for the completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
async def completion(self, **kwargs) -> Completion:
"""
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
"""
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
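Pulling the abstract pieces together, a hedged sketch of a minimal concrete subclass; the echo behavior is purely illustrative:
```python
import time
import uuid

from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)


class EchoServer(APIServer):
    """Toy backend that echoes the last user message."""

    async def check_server_status_task(self):
        self.server_healthy = True  # nothing to poll for this toy backend

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        last_user_content = kwargs["messages"][-1]["content"]
        return ChatCompletion(
            id=str(uuid.uuid4()),
            object="chat.completion",
            created=int(time.time()),
            model=self.config.model_name,
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content=last_user_content, role="assistant"),
                )
            ],
        )

    async def _completion_wrapper(self, **kwargs) -> Completion:
        raise NotImplementedError("this sketch only supports chat completions")
```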


@@ -1,4 +1,5 @@
import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
@@ -7,56 +8,56 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
class ServerManagerConfig(BaseModel):
slurm: bool = Field(
default=True, description="Whether environment is running on slurm or not."
default=False, description="Whether environment is running on slurm or not."
)
testing: bool = Field(
default=False, description="If set to True, environment uses mock OpenAI data."
)
class ServerBaseline(BaseModel):
"""
Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
assuming a 1:1 split of GPUs.
"""
timeout: int = Field(
default=1200, description="Timeout for the request in seconds."
)
num_max_requests_at_once: int = Field(
default=512,
description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
)
num_requests_for_eval: int = Field(
default=64, description="Maximum number of concurrent requests for evaluation."
)
model_name: str = Field(
default="default",
description="The model name to use. Only works with sglang, please provide the model name.",
)
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
class ServerManager:
def __init__(
self,
configs: Union[ServerBaseline, List[OpenaiConfig]],
configs: Union[ServerBaseline, List[APIServerConfig]],
server_class: APIServer = APIServer,
slurm=False,
testing=False,
):
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class; because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
if inspect.isabstract(server_class):
if not isinstance(configs, list):
if configs.server_type == "openai":
server_class = OpenAIServer
elif configs.server_type == "trl":
server_class = TrlVllmServer
else:
raise ValueError(f"Invalid server type: {configs.server_type}")
else:
if configs[0].server_type == "openai":
server_class = OpenAIServer
elif configs[0].server_type == "trl":
server_class = TrlVllmServer
else:
raise ValueError(f"Invalid server type: {configs[0].server_type}")
if testing:
# testing :)
self.servers = [ServerHarness()]
return
if isinstance(configs, ServerBaseline):
if not isinstance(configs, list):
urls = []
if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
nodelist = (
@@ -84,7 +85,7 @@ class ServerManager:
openai_configs = []
for url in urls:
openai_configs.append(
OpenaiConfig(
APIServerConfig(
base_url=url,
timeout=configs.timeout,
num_max_requests_at_once=configs.num_max_requests_at_once,
@@ -94,9 +95,9 @@
api_key="x",
)
)
self.servers = [OpenAIServer(config) for config in openai_configs]
self.servers = [server_class(config) for config in openai_configs]
elif not slurm:
self.servers = [OpenAIServer(config) for config in configs]
self.servers = [server_class(config) for config in configs]
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
@@ -109,7 +110,7 @@
"Not enough nodes to distribute to, assuming single node"
" and you've setup your sglang appropriately."
)
self.servers = [OpenAIServer(config) for config in configs]
self.servers = [server_class(config) for config in configs]
return
urls = []
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
@@ -124,7 +125,7 @@
new_conf = configs[0].model_copy(deep=True)
new_conf.base_url = urls[i]
new_configs.append(new_conf)
self.servers = [OpenAIServer(config) for config in new_configs]
self.servers = [server_class(config) for config in new_configs]
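A hedged usage note: since `APIServer` is abstract, leaving `server_class` at its default triggers the `server_type` dispatch above, so the configs alone pick the concrete class. The model name and URL below are hypothetical:
```python
manager = ServerManager(ServerBaseline())  # dispatches to OpenAIServer (server_type="openai")

trl_manager = ServerManager(
    [APIServerConfig(server_type="trl", model_name="my-model", base_url="http://localhost:8000")]
)  # dispatches to TrlVllmServer
```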
async def update_weight(self, weight: float):
for server in self.servers:


@@ -0,0 +1,126 @@
"""
This is a server that interfaces with trl's vLLM server.
Developed with much help from @winglian when they worked on integrating Atropos into Axolotl.
"""
import time
import uuid
import aiohttp
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from transformers import AutoTokenizer
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
class TrlVllmServer(APIServer):
"""
A server that interfaces with trl's vLLM server.
"""
def __init__(self, config: APIServerConfig):
self.config = config
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
super().__init__(config)
async def check_server_status_task(self):
"""
TODO: Implement server health check for trl's vLLM server
"""
self.server_healthy = True
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
Wrapper for the chat completion using trl's vLLM server.
"""
url = f"{self.config.base_url}/generate/"
prompt = kwargs.get("messages", [])
prompt = self.tokenizer.apply_chat_template(
prompt, tokenize=False, add_generation_prompt=True
)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json={
"prompts": [prompt],
"n": kwargs.get("n", 1),
"repetition_penalty": kwargs.get("repetition_penalty", 1.0),
"temperature": kwargs.get("temperature", 1.0),
"top_p": kwargs.get("top_p", 1.0),
"top_k": kwargs.get("top_k", -1),
"min_p": kwargs.get("min_p", 0.0),
"max_tokens": kwargs.get("max_tokens", 1024),
},
) as response:
completions = await response.json()
completions = ChatCompletion(
id=str(uuid.uuid4()),
object="chat.completion",
created=int(time.time()),
model=self.config.model_name,
choices=[
Choice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
else "length"
),
index=i,
message=ChatCompletionMessage(
content=self.tokenizer.decode(completion),
role="assistant",
),
)
for i, completion in enumerate(completions["completion_ids"])
],
)
return completions
async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
Wrapper for the completion using trl's vLLM server.
"""
url = f"{self.config.base_url}/generate/"
prompt = kwargs.get("prompt", "")
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json={
"prompts": [prompt],
"n": kwargs.get("n", 1),
"repetition_penalty": kwargs.get("repetition_penalty", 1.0),
"temperature": kwargs.get("temperature", 1.0),
"top_p": kwargs.get("top_p", 1.0),
"top_k": kwargs.get("top_k", -1),
"min_p": kwargs.get("min_p", 0.0),
"max_tokens": kwargs.get("max_tokens", 1024),
},
) as response:
completions = await response.json()
completions = ChatCompletion(
id=str(uuid.uuid4()),
object="chat.completion",
created=int(time.time()),
model=self.config.model_name,
choices=[
Choice(
finish_reason=(
"stop"
if self.tokenizer.eos_token_id in completion
else "length"
),
index=i,
message=ChatCompletionMessage(
content=self.tokenizer.decode(completion),
role="assistant",
),
)
for i, completion in enumerate(completions["completion_ids"])
],
)
return completions
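Finally, a hedged end-to-end sketch of driving `TrlVllmServer`; the model name and URL are assumptions, and the tokenizer must be resolvable locally or from the Hub:
```python
import asyncio

from atroposlib.envs.server_handling.server_baseline import APIServerConfig


async def main():
    server = TrlVllmServer(
        APIServerConfig(
            server_type="trl",
            model_name="Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical
            base_url="http://localhost:8000",          # assumed trl vLLM server address
        )
    )
    resp = await server.chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        n=2,
        max_tokens=64,
    )
    print([choice.message.content for choice in resp.choices])


asyncio.run(main())
```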