mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
refactor to not mess up process...
This commit is contained in:
parent
6e9405ba95
commit
df62979b90
5 changed files with 346 additions and 424 deletions
|
|
@ -1,112 +1,28 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import time
|
||||
import warnings
|
||||
from asyncio import exceptions
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
from pydantic_cli import FailedExecutionException
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
|
||||
|
||||
|
||||
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
|
||||
def __init__(self, value: int):
|
||||
super().__init__(value=value)
|
||||
self.max_val = value
|
||||
self.weight = 1.0
|
||||
class OpenAIServer(APIServer):
|
||||
"""
|
||||
OpenAI server handling.
|
||||
"""
|
||||
|
||||
def update_weight(self, weight: float) -> None:
|
||||
self.weight = weight
|
||||
|
||||
def min_val(self):
|
||||
return self.max_val * (1.0 - self.weight)
|
||||
|
||||
def release(self):
|
||||
"""Release a semaphore, incrementing the internal counter by one.
|
||||
|
||||
When it was zero on entry and another coroutine is waiting for it to
|
||||
become larger than zero again, wake up that coroutine.
|
||||
|
||||
If weight is set, it'll only wake up next if the value is greater than the max_val * weight
|
||||
"""
|
||||
self._value += 1
|
||||
if self._value > self.min_val():
|
||||
self._wake_up_next()
|
||||
|
||||
def locked(self):
|
||||
"""Returns True if semaphore cannot be acquired immediately."""
|
||||
return self._value <= self.min_val() or (
|
||||
any(not w.cancelled() for w in (self._waiters or ()))
|
||||
)
|
||||
|
||||
async def acquire(self):
|
||||
"""Acquire a semaphore.
|
||||
|
||||
If the internal counter is larger than zero on entry,
|
||||
decrement it by one and return True immediately. If it is
|
||||
zero on entry, block, waiting until some other coroutine has
|
||||
called release() to make it larger than 0, and then return
|
||||
True.
|
||||
"""
|
||||
if not self.locked():
|
||||
self._value -= 1
|
||||
return True
|
||||
|
||||
if self._waiters is None:
|
||||
self._waiters = collections.deque()
|
||||
fut = self._get_loop().create_future()
|
||||
self._waiters.append(fut)
|
||||
|
||||
# Finally block should be called before the CancelledError
|
||||
# handling as we don't want CancelledError to call
|
||||
# _wake_up_first() and attempt to wake up itself.
|
||||
try:
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
except exceptions.CancelledError:
|
||||
if not fut.cancelled():
|
||||
self._value += 1
|
||||
self._wake_up_next()
|
||||
raise
|
||||
|
||||
if self._value > self.min_val():
|
||||
self._wake_up_next()
|
||||
return True
|
||||
|
||||
|
||||
class OpenAIServer:
|
||||
def __init__(self, config: APIServerConfig):
|
||||
self.config = config
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
|
||||
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
|
||||
self.server_healthy = True
|
||||
self.attempts_list = []
|
||||
self.request_timings = []
|
||||
# in case eval is much different, we should keep different buffers
|
||||
self.eval_attempts_list = []
|
||||
self.eval_request_timings = []
|
||||
self.check_task = None
|
||||
self.initialized = False
|
||||
|
||||
async def update_weight(self, weight: float) -> None:
|
||||
# need to update sems
|
||||
self.sem.update_weight(weight)
|
||||
self.eval_sem.update_weight(weight)
|
||||
super().__init__(config)
|
||||
|
||||
async def check_server_status_task(self):
|
||||
while True:
|
||||
|
|
@ -126,42 +42,16 @@ class OpenAIServer:
|
|||
self.server_healthy = False
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def wandb_metrics(
|
||||
self, metrics_dict: Optional[dict], server_name: Optional[str]
|
||||
):
|
||||
if server_name is None:
|
||||
server_name = "server"
|
||||
if len(self.request_timings) > 0:
|
||||
metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
|
||||
self.request_timings
|
||||
)
|
||||
metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
|
||||
self.request_timings
|
||||
)
|
||||
metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
|
||||
self.request_timings, 99
|
||||
)
|
||||
if len(self.eval_request_timings) > 0:
|
||||
metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
|
||||
self.eval_request_timings
|
||||
)
|
||||
metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
|
||||
self.eval_request_timings
|
||||
)
|
||||
metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
|
||||
self.eval_request_timings, 99
|
||||
)
|
||||
if len(self.attempts_list) > 0:
|
||||
metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
|
||||
self.attempts_list
|
||||
)
|
||||
if len(self.eval_attempts_list) > 0:
|
||||
metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
|
||||
self.eval_attempts_list
|
||||
)
|
||||
return metrics_dict
|
||||
|
||||
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Wrapper for the chat completion using the openai client.
|
||||
"""
|
||||
assert (
|
||||
kwargs.get("model", None) is not None
|
||||
), "Model is required for chat completion!"
|
||||
assert (
|
||||
kwargs.get("messages", None) is not None
|
||||
), "Messages are required for chat completion!"
|
||||
if self.config.n_kwarg_is_ignored:
|
||||
n = kwargs.pop("n", 1)
|
||||
completion_list = await asyncio.gather(
|
||||
|
|
@ -197,62 +87,16 @@ class OpenAIServer:
|
|||
completions.choices.extend(c.choices)
|
||||
return completions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self._chat_completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.eval_sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self._chat_completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
if not self.initialized:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(self.check_server_status_task())
|
||||
else:
|
||||
self.server_healthy = True
|
||||
self.initialized = True
|
||||
kwargs["model"] = self.config.model_name
|
||||
split = kwargs.pop("split", "train")
|
||||
stat_dict = {}
|
||||
stat_dict["attempts"] = 0
|
||||
if split == "train":
|
||||
ret_data = await self._chat_comp(stat_dict, **kwargs)
|
||||
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.attempts_list.append(stat_dict["attempts"])
|
||||
else:
|
||||
# Give separate eval workers, if desired, gotta go fast for those evals
|
||||
ret_data = await self._chat_eval(stat_dict, **kwargs)
|
||||
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Wrapper for the completion using the openai client.
|
||||
"""
|
||||
assert (
|
||||
kwargs.get("model", None) is not None
|
||||
), "Model is required for completion!"
|
||||
assert (
|
||||
kwargs.get("prompt", None) is not None
|
||||
), "Prompt is required for completion!"
|
||||
if self.config.n_kwarg_is_ignored:
|
||||
n = kwargs.pop("n", 1)
|
||||
completion_list = await asyncio.gather(
|
||||
|
|
@ -281,58 +125,7 @@ class OpenAIServer:
|
|||
)
|
||||
for c in completion_list:
|
||||
completions.choices.extend(c.choices)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _comp(self, stat_dict, **kwargs) -> Completion:
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self._completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
|
||||
while not self.server_healthy:
|
||||
await asyncio.sleep(1)
|
||||
async with self.eval_sem:
|
||||
if stat_dict.get("start", None) is None:
|
||||
stat_dict["start"] = time.time()
|
||||
stat_dict["attempts"] += 1
|
||||
completions = await self._completion_wrapper(**kwargs)
|
||||
stat_dict["end"] = time.time()
|
||||
return completions
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
if not self.initialized:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(self.check_server_status_task())
|
||||
else:
|
||||
self.server_healthy = True
|
||||
self.initialized = True
|
||||
kwargs["model"] = self.config.model_name
|
||||
split = kwargs.pop("split", "train")
|
||||
stat_dict = {}
|
||||
stat_dict["attempts"] = 0
|
||||
if split == "train":
|
||||
ret_data = await self._comp(stat_dict, **kwargs)
|
||||
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.attempts_list.append(stat_dict["attempts"])
|
||||
else:
|
||||
# Give separate eval workers, if desired, gotta go fast for those evals
|
||||
ret_data = await self._comp_eval(stat_dict, **kwargs)
|
||||
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
return completions
|
||||
|
||||
|
||||
def resolve_openai_configs(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue