mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
Merge branch 'main' into blackjack2-env
commit 00dd120067
34 changed files with 1620 additions and 386 deletions
@ -58,10 +58,38 @@ These methods have default implementations or are optional based on your needs:

* **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.
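    For example, a subclass might append its own state before delegating to the base implementation (a minimal sketch; the `episode_count` attribute is purely illustrative):

    ```python
    from atroposlib.envs import BaseEnv


    class MyEnv(BaseEnv):
        def save_checkpoint(self, step, data=None):
            data = dict(data or {})
            data["episode_count"] = self.episode_count  # hypothetical env-specific state
            super().save_checkpoint(step, data)  # the base class still writes the JSON file
    ```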
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command.

* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[APIServerConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `APIServerConfig` instances) when run via the CLI `serve` command.
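    An override sketch pinning two explicit endpoints instead of the default `ServerBaseline` (the model name, ports, and the `"x"` key are placeholders for a local setup):

    ```python
    from typing import List, Tuple, Union

    from atroposlib.envs import BaseEnv, BaseEnvConfig
    from atroposlib.envs.server_handling.server_baseline import (
        APIServerConfig,
        ServerBaseline,
    )


    class MyEnv(BaseEnv):
        @classmethod
        def config_init(
            cls,
        ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
            # Two local inference endpoints instead of the ServerBaseline default
            return cls.env_config_cls(), [
                APIServerConfig(model_name="my-model", base_url="http://localhost:9004", api_key="x"),
                APIServerConfig(model_name="my-model", base_url="http://localhost:9005", api_key="x"),
            ]
    ```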
* **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.

## Overrideable Class Variables

These class-level variables in `BaseEnv` can be overridden in your subclass to customize its behavior:
* **`name: Optional[str]`**:
    * Default: `None`
    * Purpose: You can set a string name for your environment. This name is used by default for `wandb_name` in the `BaseEnvConfig` if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.

* **`env_config_cls: Type[BaseEnvConfig]`**:
    * Default: `BaseEnvConfig`
    * Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what `BaseEnvConfig` offers, you should create a new class that inherits from `BaseEnvConfig` (or a subclass thereof) and assign it to `env_config_cls`. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.
    ```python
    from pydantic import Field

    from atroposlib.envs import BaseEnv, BaseEnvConfig


    class MyEnvConfig(BaseEnvConfig):
        my_custom_param: str = Field(default="default_value", description="A custom parameter for MyEnv")


    class MyEnv(BaseEnv):
        env_config_cls = MyEnvConfig
        name = "MyCustomEnvironment"
        # ... other implementations
    ```
* **`server_cls: Type[APIServer]`**:
    * Default: `APIServer`
    * Purpose: Specifies the class used for managing interactions with API servers (e.g., inference endpoints). It is mostly intended for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API, you can use this to easily slot in any modifications you need.
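    A sketch of slotting in a custom server class; the overridden hooks are the `APIServer` abstract methods introduced in `server_baseline.py` below, while the request logic itself is left illustrative:

    ```python
    from atroposlib.envs import BaseEnv
    from atroposlib.envs.server_handling.server_baseline import APIServer


    class MyPatchedServer(APIServer):
        async def check_server_status_task(self):
            self.server_healthy = True  # e.g., poll a custom health endpoint instead

        async def _chat_completion_wrapper(self, **kwargs):
            ...  # call the nonstandard API, return a ChatCompletion

        async def _completion_wrapper(self, **kwargs):
            ...  # call the nonstandard API, return a Completion


    class MyEnv(BaseEnv):
        server_cls = MyPatchedServer
    ```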
## Provided Functionality

`BaseEnv` provides several helpful features:
@ -40,7 +40,8 @@ from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
from .server_handling.server_manager import (
    OpenaiConfig,
    APIServer,
    APIServerConfig,
    ServerBaseline,
    ServerManager,
    ServerManagerConfig,
@ -163,13 +164,14 @@ class BaseEnvConfig(BaseModel):
class BaseEnv(ABC):

    name = None
    env_config_cls = BaseEnvConfig
    name: Optional[str] = None
    env_config_cls: BaseEnvConfig = BaseEnvConfig
    server_cls: APIServer = APIServer

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: Union[ServerBaseline, List[OpenaiConfig]],
        server_configs: Union[ServerBaseline, List[APIServerConfig]],
        slurm=False,
        testing=False,
    ):
@ -184,7 +186,9 @@ class BaseEnv(ABC):
        self.last_loop_time = None
        self.last_completed_item = None
        self.config = config
        self.server = ServerManager(server_configs, slurm=slurm, testing=testing)
        self.server = ServerManager(
            server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
        )
        self.workers = set()
        self.eval_workers = set()
        self.backlog = []
@ -234,7 +238,7 @@ class BaseEnv(ABC):
    @classmethod
    def config_init(
        cls,
    ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]:
    ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
        """
        Initialize the config
        """
@ -1020,7 +1024,6 @@ class BaseEnv(ABC):
        Returns:
            type: The CliServeConfig class for serving commands.
        """
        # Get the default configurations defined by the specific environment class
        default_env_config, default_server_configs = cls.config_init()
@ -1032,8 +1035,8 @@ class BaseEnv(ABC):
        class CliServeConfig(
            get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
            get_prefixed_pydantic_model(
                OpenaiConfig, openai_full_prefix
            ),  # Use OpenaiConfig for CLI args
                APIServerConfig, openai_full_prefix
            ),  # Use APIServerConfig for CLI args
            ServerManagerConfig,  # ServerManager args are not namespaced by default
            Cmd,
        ):
@ -1089,7 +1092,7 @@ class BaseEnv(ABC):
                oai_cli_passed_args or yaml_oai_config
            ):
                raise ValueError(
                    "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig."  # noqa: E501
                    "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig."  # noqa: E501
                )
            if (
                isinstance(default_server_configs, list)
@ -1101,11 +1104,11 @@ class BaseEnv(ABC):
            default_openai_config_ = default_server_configs
            if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                yaml_oai_config = yaml_oai_config[0]
            if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
            if isinstance(default_openai_config_, APIServerConfig) and isinstance(
                yaml_oai_config, dict
            ):
                openai_config_dict = merge_dicts(
                    default_openai_config_.model_dump(),  # Default OpenaiConfig (or from class init)
                    default_openai_config_.model_dump(),  # Default APIServerConfig (or from class init)
                    yaml_oai_config,
                    oai_cli_passed_args,
                )
@ -1189,7 +1192,7 @@ class BaseEnv(ABC):
            data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
            use_wandb=True,
        )
        PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
        PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
            model_name="gpt-4.1-nano",
            base_url=None,
            api_key=None,
@ -1200,10 +1203,7 @@ class BaseEnv(ABC):
        )

        # Get the base default configurations from the specific environment class
        (
            default_env_config,
            default_server_configs,
        ) = cls.config_init()
        default_env_config, default_server_configs = cls.config_init()

        # Define namespace prefixes
        env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@ -1215,8 +1215,7 @@ class BaseEnv(ABC):
            type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
        )
        openai_config_cls_new_defaults = adjust_model_defaults(
            OpenaiConfig,
            PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
            APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
        )
        server_manager_config_cls_new_defaults = adjust_model_defaults(
            ServerManagerConfig,
@ -1283,7 +1282,7 @@ class BaseEnv(ABC):
            oai_cli_passed_args or yaml_oai_config
        ):
            raise ValueError(
                "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig."  # noqa: E501
                "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig."  # noqa: E501
            )

        if (
@ -1296,11 +1295,11 @@ class BaseEnv(ABC):
            default_openai_config_ = default_server_configs
            if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                yaml_oai_config = yaml_oai_config[0]
            if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
            if isinstance(default_openai_config_, APIServerConfig) and isinstance(
                yaml_oai_config, dict
            ):
                openai_config_dict = merge_dicts(
                    default_openai_config_.model_dump(),  # Default OpenaiConfig (or from class init)
                    default_openai_config_.model_dump(),  # Default APIServerConfig (or from class init)
                    PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),  # Process Mode Defaults
                    yaml_oai_config,
                    oai_cli_passed_args,
@ -1,142 +1,28 @@
import asyncio
import collections
import time
from asyncio import exceptions
from typing import Optional
import warnings

import aiohttp
import numpy as np
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from pydantic_cli import FailedExecutionException
from tenacity import retry, stop_after_attempt, wait_random_exponential

from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig


class OpenaiConfig(BaseModel):
class OpenAIServer(APIServer):
    """
    Configuration for the server manager.
    OpenAI server handling.
    """

    api_key: Optional[str] = Field(
        default=None, description="API key for OpenAI API. Use 'x' for local servers."
    )
    base_url: Optional[str] = Field(
        default=None,
        description="URL of the API endpoint. None if using official OpenAI API, otherwise local server URL.",
    )
    timeout: int = Field(
        default=1200, description="Timeout for the request in seconds."
    )
    num_max_requests_at_once: int = Field(
        default=512,
        description="Maximum number of concurrent requests. Note: You should divide this by the n kwarg.",
    )
    num_requests_for_eval: int = Field(
        default=64, description="Maximum number of concurrent requests for evaluation."
    )
    model_name: str = Field(
        default="default",
        description="The model name to use. Required for both OpenAI and local models.",
    )
    rolling_buffer_length: int = Field(
        default=1000, description="Length of the rolling buffer to store metrics."
    )


class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
    def __init__(self, value: int):
        super().__init__(value=value)
        self.max_val = value
        self.weight = 1.0

    def update_weight(self, weight: float) -> None:
        self.weight = weight

    def min_val(self):
        return self.max_val * (1.0 - self.weight)

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When it was zero on entry and another coroutine is waiting for it to
        become larger than zero again, wake up that coroutine.

        If weight is set, it'll only wake up next if the value is greater than the max_val * weight
        """
        self._value += 1
        if self._value > self.min_val():
            self._wake_up_next()

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        return self._value <= self.min_val() or (
            any(not w.cancelled() for w in (self._waiters or ()))
        )

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry,
        decrement it by one and return True immediately. If it is
        zero on entry, block, waiting until some other coroutine has
        called release() to make it larger than 0, and then return
        True.
        """
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_first() and attempt to wake up itself.
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            if not fut.cancelled():
                self._value += 1
                self._wake_up_next()
            raise

        if self._value > self.min_val():
            self._wake_up_next()
        return True


class OpenAIServer:
    def __init__(self, config: OpenaiConfig):
        self.config = config
    def __init__(self, config: APIServerConfig):
        self.openai = openai.AsyncClient(
            api_key=config.api_key,
            base_url=config.base_url,
            timeout=config.timeout,
        )
        self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
        self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
        self.server_healthy = True
        self.attempts_list = []
        self.request_timings = []
        # in case eval is much different, we should keep different buffers
        self.eval_attempts_list = []
        self.eval_request_timings = []
        self.check_task = None
        self.initialized = False

    async def update_weight(self, weight: float) -> None:
        # need to update sems
        self.sem.update_weight(weight)
        self.eval_sem.update_weight(weight)
        super().__init__(config)

    async def check_server_status_task(self):
        while True:
@ -156,147 +42,90 @@ class OpenAIServer:
                self.server_healthy = False
            await asyncio.sleep(1)

    async def wandb_metrics(
        self, metrics_dict: Optional[dict], server_name: Optional[str]
    ):
        if server_name is None:
            server_name = "server"
        if len(self.request_timings) > 0:
            metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
                self.request_timings
    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for chat completion!"
        assert (
            kwargs.get("messages", None) is not None
        ), "Messages are required for chat completion!"
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
            )
            metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
                self.request_timings
            )
            metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
                self.request_timings, 99
            )
        if len(self.eval_request_timings) > 0:
            metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
                self.eval_request_timings
            )
            metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
                self.eval_request_timings
            )
            metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
                self.eval_request_timings, 99
            )
        if len(self.attempts_list) > 0:
            metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
                self.attempts_list
            )
        if len(self.eval_attempts_list) > 0:
            metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
                self.eval_attempts_list
            )
        return metrics_dict

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self.openai.chat.completions.create(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self.openai.chat.completions.create(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def chat_completion(self, **kwargs) -> ChatCompletion:
        if not self.initialized:
            if (
                self.config.base_url is not None
            ):  # skip health check if using OpenAI API
                self.check_task = asyncio.create_task(self.check_server_status_task())
            completions = completion_list[0]
            if n > 1:
                for c in completion_list[1:]:
                    completions.choices.extend(c.choices)
            else:
                self.server_healthy = True
            self.initialized = True
        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {}
        stat_dict["attempts"] = 0
        if split == "train":
            ret_data = await self._chat_comp(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
            completions = await self.openai.chat.completions.create(**kwargs)
        else:
            # Give separate eval workers, if desired, gotta go fast for those evals
            ret_data = await self._chat_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _comp(self, stat_dict, **kwargs) -> Completion:
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self.openai.completions.create(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self.openai.completions.create(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    async def completion(self, **kwargs) -> Completion:
        if not self.initialized:
            if (
                self.config.base_url is not None
            ):  # skip health check if using OpenAI API
                self.check_task = asyncio.create_task(self.check_server_status_task())
            if "n" in kwargs:
                n = kwargs["n"]
            else:
                self.server_healthy = True
            self.initialized = True
        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {}
        stat_dict["attempts"] = 0
        if split == "train":
            ret_data = await self._comp(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
                n = 1
            completions = await self.openai.chat.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                else:
                    warnings.warn("n kwarg is ignored by the API, setting to True")
                    self.config.n_kwarg_is_ignored = True
                    completion_list = await asyncio.gather(
                        *[
                            self.openai.chat.completions.create(**kwargs)
                            for _ in range(1, n)
                        ]
                    )
                    for c in completion_list:
                        completions.choices.extend(c.choices)
        return completions

    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
        ), "Prompt is required for completion!"
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.completions.create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            if n > 1:
                for c in completion_list[1:]:
                    completions.choices.extend(c.choices)
        else:
            # Give separate eval workers, if desired, gotta go fast for those evals
            ret_data = await self._comp_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
            return ret_data
            if "n" in kwargs:
                n = kwargs["n"]
            else:
                n = 1
            completions = await self.openai.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                else:
                    warnings.warn("n kwarg is ignored by the API, setting to True")
                    self.config.n_kwarg_is_ignored = True
                    completion_list = await asyncio.gather(
                        *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
                    )
                    for c in completion_list:
                        completions.choices.extend(c.choices)
        return completions


def resolve_openai_configs(
@ -338,7 +167,7 @@ def resolve_openai_configs(
            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
        )
        try:
            server_configs = [OpenaiConfig(**cfg) for cfg in openai_yaml_config]
            server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config]
        except Exception as e:
            raise FailedExecutionException(
                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
@ -354,14 +183,14 @@ def resolve_openai_configs(
            "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
        )
        try:
            final_openai_config = OpenaiConfig(**openai_config_dict)
            final_openai_config = APIServerConfig(**openai_config_dict)
        except Exception as e:
            raise FailedExecutionException(
                f"Error creating final OpenAI configuration from merged settings: {e}\n"
                f"Merged Dict: {openai_config_dict}"
            ) from e

    if isinstance(default_server_configs, OpenaiConfig):
    if isinstance(default_server_configs, APIServerConfig):
        server_configs = final_openai_config
    elif isinstance(default_server_configs, list):
        server_configs = [final_openai_config]
atroposlib/envs/server_handling/server_baseline.py (new file, 340 lines)
@ -0,0 +1,340 @@
import asyncio
import collections
import time
from abc import ABC, abstractmethod
from asyncio import exceptions
from typing import Literal, Optional

import numpy as np
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential


class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
    def __init__(self, value: int):
        super().__init__(value=value)
        self.max_val = value
        self.weight = 1.0

    def update_weight(self, weight: float) -> None:
        """
        Update the weight of the semaphore.
        """
        self.weight = weight

    def min_val(self):
        """
        Returns the minimum value of the semaphore.
        """
        return self.max_val * (1.0 - self.weight)

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When it was zero on entry and another coroutine is waiting for it to
        become larger than zero again, wake up that coroutine.

        If weight is set, it'll only wake up next if the value is greater than the max_val * weight
        """
        self._value += 1
        if self._value > self.min_val():
            self._wake_up_next()

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        return self._value <= self.min_val() or (
            any(not w.cancelled() for w in (self._waiters or ()))
        )

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry,
        decrement it by one and return True immediately. If it is
        zero on entry, block, waiting until some other coroutine has
        called release() to make it larger than 0, and then return
        True.
        """
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_first() and attempt to wake up itself.
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            if not fut.cancelled():
                self._value += 1
                self._wake_up_next()
            raise

        if self._value > self.min_val():
            self._wake_up_next()
        return True


class ServerBaseline(BaseModel):
    """
    Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
    assuming a 1:1 split of GPUs.
    """

    timeout: int = Field(
        default=1200, description="Timeout for the request in seconds."
    )
    num_max_requests_at_once: int = Field(
        default=512,
        description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
    )
    num_requests_for_eval: int = Field(
        default=64, description="Maximum number of concurrent requests for evaluation."
    )
    model_name: str = Field(
        default="default",
        description="The model name to use. Only works with sglang, please provide the model name.",
    )
    rolling_buffer_length: int = Field(
        default=1000, description="Length of the rolling buffer to store metrics."
    )
    server_type: Literal["openai", "trl"] = Field(
        default="openai", description="Type of server to use, openai or trl"
    )


class APIServerConfig(ServerBaseline):
    """
    API server configuration.
    """

    api_key: Optional[str] = Field(default="", description="API key for the server.")
    base_url: Optional[str] = Field(default="", description="Base URL for the server.")
    n_kwarg_is_ignored: bool = Field(
        default=False, description="Whether the n kwarg is ignored by this API server."
    )


class APIServer(ABC):
    """
    Abstract class for API servers.
    """

    def __init__(self, config: APIServerConfig):
        self.config = config
        self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
        self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
        self.server_healthy = True
        self.attempts_list = []
        self.request_timings = []
        # in case eval is much different, we should keep different buffers
        self.eval_attempts_list = []
        self.eval_request_timings = []
        self.check_task = None
        self.initialized = False

    async def update_weight(self, weight: float) -> None:
        """
        Update the weight of the semaphores
        """
        # need to update sems
        self.sem.update_weight(weight)
        self.eval_sem.update_weight(weight)

    @abstractmethod
    async def check_server_status_task(self):
        """
        Check the status of the server. Should be overridden by the child class.
        Set self.server_healthy to True if the server is healthy.
        """
        self.server_healthy = False

    async def wandb_metrics(
        self, metrics_dict: Optional[dict], server_name: Optional[str]
    ):
        """
        Add metrics to the metrics dictionary.

        If you want to add more metrics, you can do so by overriding this method, but make sure to call
        super().wandb_metrics(metrics_dict, server_name) first to get the default metrics, if you still want them.
        """
        if server_name is None:
            server_name = "server"
        if len(self.request_timings) > 0:
            metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
                self.request_timings
            )
            metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
                self.request_timings
            )
            metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
                self.request_timings, 99
            )
        if len(self.eval_request_timings) > 0:
            metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
                self.eval_request_timings
            )
            metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
                self.eval_request_timings
            )
            metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
                self.eval_request_timings, 99
            )
        if len(self.attempts_list) > 0:
            metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
                self.attempts_list
            )
        if len(self.eval_attempts_list) > 0:
            metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
                self.eval_attempts_list
            )
        return metrics_dict

    @abstractmethod
    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion. Should be overridden by the child class and return a ChatCompletion object.
        """
        pass

    @abstractmethod
    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion. Should be overridden by the child class and return a Completion object.
        """
        pass

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
        """
        Simple retry and stat collection wrapper for the chat completion.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._chat_completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
        """
        Simple retry and stat collection wrapper for the chat completion.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._chat_completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """
        Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
        """
        if not self.initialized:
            if (
                self.config.base_url is not None
            ):  # skip health check if using OpenAI API
                self.check_task = asyncio.create_task(self.check_server_status_task())
            else:
                self.server_healthy = True
            self.initialized = True
        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {}
        stat_dict["attempts"] = 0
        if split == "train":
            ret_data = await self._chat_comp(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
        else:
            # Give separate eval workers, if desired, gotta go fast for those evals
            ret_data = await self._chat_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _comp(self, stat_dict, **kwargs) -> Completion:
        """
        Simple retry and stat collection wrapper for the completion.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
        """
        Simple retry and stat collection wrapper for the completion.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    async def completion(self, **kwargs) -> Completion:
        """
        Completion handler, waits for the server to be healthy and then calls the completion wrapper.
        """
        if not self.initialized:
            if (
                self.config.base_url is not None
            ):  # skip health check if using OpenAI API
                self.check_task = asyncio.create_task(self.check_server_status_task())
            else:
                self.server_healthy = True
            self.initialized = True
        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {}
        stat_dict["attempts"] = 0
        if split == "train":
            ret_data = await self._comp(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
        else:
            # Give separate eval workers, if desired, gotta go fast for those evals
            ret_data = await self._comp_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data
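With this split of responsibilities, a concrete server only implements the `_chat_completion_wrapper`/`_completion_wrapper` hooks, while `chat_completion`/`completion` add the retries, health gating, and the `split` kwarg that routes a request to the train or eval semaphore. A hypothetical call site as a sketch (the model name mirrors the process-mode default above; the message is illustrative):

```python
import asyncio

from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig


async def main():
    # base_url=None / api_key=None means the official OpenAI API with the key
    # taken from the environment, as in PROCESS_MODE_OPENAI_DEFAULT_CONFIG above
    server = OpenAIServer(
        APIServerConfig(model_name="gpt-4.1-nano", base_url=None, api_key=None)
    )
    resp = await server.chat_completion(
        messages=[{"role": "user", "content": "Say hi."}],
        n=2,
        split="eval",  # routed through eval_sem and the eval timing buffers
    )
    print(len(resp.choices))


asyncio.run(main())
```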
@ -1,4 +1,5 @@
import asyncio
import inspect
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
@ -7,56 +8,56 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field

from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
    APIServer,
    APIServerConfig,
    ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer


class ServerManagerConfig(BaseModel):
    slurm: bool = Field(
        default=True, description="Whether environment is running on slurm or not."
        default=False, description="Whether environment is running on slurm or not."
    )
    testing: bool = Field(
        default=False, description="If set to True, environment uses mock OpenAI data."
    )


class ServerBaseline(BaseModel):
    """
    Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
    assuming a 1:1 split of GPUs.
    """

    timeout: int = Field(
        default=1200, description="Timeout for the request in seconds."
    )
    num_max_requests_at_once: int = Field(
        default=512,
        description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
    )
    num_requests_for_eval: int = Field(
        default=64, description="Maximum number of concurrent requests for evaluation."
    )
    model_name: str = Field(
        default="default",
        description="The model name to use. Only works with sglang, please provide the model name.",
    )
    rolling_buffer_length: int = Field(
        default=1000, description="Length of the rolling buffer to store metrics."
    )


class ServerManager:
    def __init__(
        self,
        configs: Union[ServerBaseline, List[OpenaiConfig]],
        configs: Union[ServerBaseline, List[APIServerConfig]],
        server_class: APIServer = APIServer,
        slurm=False,
        testing=False,
    ):
        # First we check to see if it's the base server class, and if so, we need to select the appropriate server class
        # You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
        # an ABCMeta, not what you're expecting.
        if inspect.isabstract(server_class):
            if not isinstance(configs, list):
                if configs.server_type == "openai":
                    server_class = OpenAIServer
                elif configs.server_type == "trl":
                    server_class = TrlVllmServer
                else:
                    raise ValueError(f"Invalid server type: {configs.server_type}")
            else:
                if configs[0].server_type == "openai":
                    server_class = OpenAIServer
                elif configs[0].server_type == "trl":
                    server_class = TrlVllmServer
                else:
                    raise ValueError(f"Invalid server type: {configs[0].server_type}")
        if testing:
            # testing :)
            self.servers = [ServerHarness()]
            return
        if isinstance(configs, ServerBaseline):
        if not isinstance(configs, list):
            urls = []
            if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
                nodelist = (
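Because `BaseEnv.server_cls` defaults to the abstract `APIServer`, the `inspect.isabstract` branch above is the usual path, and the `server_type` field on the config picks the concrete class. A minimal sketch of both spellings (model name and port are placeholders; 9004 is the first local port mentioned in `ServerBaseline`'s docstring):

```python
from atroposlib.envs.server_handling.server_baseline import (
    APIServerConfig,
    ServerBaseline,
)
from atroposlib.envs.server_handling.server_manager import ServerManager

# server_type defaults to "openai", so this resolves to OpenAIServer
manager = ServerManager(ServerBaseline())

# An explicit endpoint list with server_type="trl" resolves to TrlVllmServer
manager = ServerManager(
    [
        APIServerConfig(
            model_name="my-model",  # placeholder
            base_url="http://localhost:9004",
            server_type="trl",
        )
    ]
)
```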
@ -84,7 +85,7 @@ class ServerManager:
            openai_configs = []
            for url in urls:
                openai_configs.append(
                    OpenaiConfig(
                    APIServerConfig(
                        base_url=url,
                        timeout=configs.timeout,
                        num_max_requests_at_once=configs.num_max_requests_at_once,
@ -94,9 +95,9 @@ class ServerManager:
                        api_key="x",
                    )
                )
            self.servers = [OpenAIServer(config) for config in openai_configs]
            self.servers = [server_class(config) for config in openai_configs]
        elif not slurm:
            self.servers = [OpenAIServer(config) for config in configs]
            self.servers = [server_class(config) for config in configs]
        else:
            nodelist = (
                os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
@ -109,7 +110,7 @@ class ServerManager:
                    "Not enough nodes to distribute to, assuming single node"
                    " and you've setup your sglang appropriately."
                )
                self.servers = [OpenAIServer(config) for config in configs]
                self.servers = [server_class(config) for config in configs]
                return
            urls = []
            num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
@ -124,7 +125,7 @@ class ServerManager:
                new_conf = configs[0].model_copy(deep=True)
                new_conf.base_url = urls[i]
                new_configs.append(new_conf)
            self.servers = [OpenAIServer(config) for config in new_configs]
            self.servers = [server_class(config) for config in new_configs]

    async def update_weight(self, weight: float):
        for server in self.servers:
atroposlib/envs/server_handling/trl_vllm_server.py (new file, 126 lines)
@ -0,0 +1,126 @@
"""
This is a server that interfaces with trl's vLLM server.

Developed with much help from @winglian when they worked on integrating Atropos into Axolotl.
"""

import time
import uuid

import aiohttp
from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)
from transformers import AutoTokenizer

from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig


class TrlVllmServer(APIServer):
    """
    A server that interfaces with trl's vLLM server.
    """

    def __init__(self, config: APIServerConfig):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        super().__init__(config)

    async def check_server_status_task(self):
        """
        TODO: Implement server health check for trl's vLLM server
        """
        self.server_healthy = True

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the trl's vLLM server.
        """
        url = f"{self.config.base_url}/generate/"
        prompt = kwargs.get("messages", [])
        prompt = self.tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json={
                    "prompts": [prompt],
                    "n": kwargs.get("n", 1),
                    "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
                    "temperature": kwargs.get("temperature", 1.0),
                    "top_p": kwargs.get("top_p", 1.0),
                    "top_k": kwargs.get("top_k", -1),
                    "min_p": kwargs.get("min_p", 0.0),
                    "max_tokens": kwargs.get("max_tokens", 1024),
                },
            ) as response:
                completions = await response.json()
                completions = ChatCompletion(
                    id=str(uuid.uuid4()),
                    object="chat.completion",
                    created=int(time.time()),
                    model=self.config.model_name,
                    choices=[
                        Choice(
                            finish_reason=(
                                "stop"
                                if self.tokenizer.eos_token_id in completion
                                else "length"
                            ),
                            index=i,
                            message=ChatCompletionMessage(
                                content=self.tokenizer.decode(completion),
                                role="assistant",
                            ),
                        )
                        for i, completion in enumerate(completions["completion_ids"])
                    ],
                )
                return completions

    async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the completion using the trl's vLLM server.
        """
        url = f"{self.config.base_url}/generate/"
        prompt = kwargs.get("prompt", "")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json={
                    "prompts": [prompt],
                    "n": kwargs.get("n", 1),
                    "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
                    "temperature": kwargs.get("temperature", 1.0),
                    "top_p": kwargs.get("top_p", 1.0),
                    "top_k": kwargs.get("top_k", -1),
                    "min_p": kwargs.get("min_p", 0.0),
                    "max_tokens": kwargs.get("max_tokens", 1024),
                },
            ) as response:
                completions = await response.json()
                completions = ChatCompletion(
                    id=str(uuid.uuid4()),
                    object="chat.completion",
                    created=int(time.time()),
                    model=self.config.model_name,
                    choices=[
                        Choice(
                            finish_reason=(
                                "stop"
                                if self.tokenizer.eos_token_id in completion
                                else "length"
                            ),
                            index=i,
                            message=ChatCompletionMessage(
                                content=self.tokenizer.decode(completion),
                                role="assistant",
                            ),
                        )
                        for i, completion in enumerate(completions["completion_ids"])
                    ],
                )
                return completions
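A hypothetical wiring for this class as a sketch: the constructor loads a tokenizer from `config.model_name`, so the name must be a checkpoint `AutoTokenizer` can resolve, and the URL below is a placeholder for wherever trl's vLLM server is listening:

```python
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer

server = TrlVllmServer(
    APIServerConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder HF repo for the tokenizer
        base_url="http://localhost:8000",  # placeholder trl vLLM server address
        server_type="trl",
    )
)
```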