mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
added health check flag to skip entirely
This commit is contained in:
parent
f84934363c
commit
c05d9f7f53
3 changed files with 34 additions and 13 deletions
|
|
@ -24,7 +24,12 @@ class OpenAIServer(APIServer):
|
|||
)
|
||||
super().__init__(config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
async def check_server_status_task(
|
||||
self, chat_completion: bool = True, skip_check: bool = False
|
||||
):
|
||||
if skip_check:
|
||||
self.server_healthy = True
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
if chat_completion:
|
||||
|
|
|
|||
|
|
@ -123,6 +123,9 @@ class APIServerConfig(ServerBaseline):
|
|||
n_kwarg_is_ignored: bool = Field(
|
||||
default=False, description="Whether the n kwarg is ignored by this API server."
|
||||
)
|
||||
health_check: bool = Field(
|
||||
default=True, description="Whether to perform a health check on the server."
|
||||
)
|
||||
|
||||
|
||||
class APIServer(ABC):
|
||||
|
|
@ -152,7 +155,9 @@ class APIServer(ABC):
|
|||
self.eval_sem.update_weight(weight)
|
||||
|
||||
@abstractmethod
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
async def check_server_status_task(
|
||||
self, chat_completion: bool = True, skip_check: bool = False
|
||||
):
|
||||
"""
|
||||
Check the status of the server. Should be overridden by the child class.
|
||||
Set self.server_healthy to True if the server is healthy.
|
||||
|
|
@ -256,10 +261,15 @@ class APIServer(ABC):
|
|||
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
|
||||
"""
|
||||
if not self.initialized:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(self.check_server_status_task())
|
||||
if self.config.health_check:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(
|
||||
self.check_server_status_task()
|
||||
)
|
||||
else:
|
||||
self.server_healthy = True
|
||||
else:
|
||||
self.server_healthy = True
|
||||
self.initialized = True
|
||||
|
|
@ -317,13 +327,17 @@ class APIServer(ABC):
|
|||
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
|
||||
"""
|
||||
if not self.initialized:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(
|
||||
self.check_server_status_task(chat_completion=False)
|
||||
)
|
||||
if self.config.health_check:
|
||||
if (
|
||||
self.config.base_url is not None
|
||||
): # skip health check if using OpenAI API
|
||||
self.check_task = asyncio.create_task(
|
||||
self.check_server_status_task(chat_completion=False)
|
||||
)
|
||||
else:
|
||||
self.server_healthy = True
|
||||
else:
|
||||
# If health_check is False, always assume healthy
|
||||
self.server_healthy = True
|
||||
self.initialized = True
|
||||
kwargs["model"] = self.config.model_name
|
||||
|
|
|
|||
|
|
@ -28,7 +28,9 @@ class TrlVllmServer(APIServer):
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
super().__init__(config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
async def check_server_status_task(
|
||||
self, chat_completion: bool = True, skip_check: bool = False
|
||||
):
|
||||
"""
|
||||
TODO: Implement server health check for trl's vLLM server
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue