Add `health_check` config flag to skip the server health check entirely

This commit is contained in:
Shannon Sands 2025-05-17 13:36:26 -07:00
parent f84934363c
commit c05d9f7f53
3 changed files with 34 additions and 13 deletions

View file

@ -24,7 +24,12 @@ class OpenAIServer(APIServer):
)
super().__init__(config)
async def check_server_status_task(self, chat_completion: bool = True):
async def check_server_status_task(
self, chat_completion: bool = True, skip_check: bool = False
):
if skip_check:
self.server_healthy = True
return
while True:
try:
if chat_completion:

View file

@ -123,6 +123,9 @@ class APIServerConfig(ServerBaseline):
n_kwarg_is_ignored: bool = Field(
default=False, description="Whether the n kwarg is ignored by this API server."
)
health_check: bool = Field(
default=True, description="Whether to perform a health check on the server."
)
class APIServer(ABC):
@ -152,7 +155,9 @@ class APIServer(ABC):
self.eval_sem.update_weight(weight)
@abstractmethod
async def check_server_status_task(self, chat_completion: bool = True):
async def check_server_status_task(
self, chat_completion: bool = True, skip_check: bool = False
):
"""
Check the status of the server. Should be overridden by the child class.
Set self.server_healthy to True if the server is healthy.
@ -256,10 +261,15 @@ class APIServer(ABC):
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
"""
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
if self.config.health_check:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(
self.check_server_status_task()
)
else:
self.server_healthy = True
else:
self.server_healthy = True
self.initialized = True
@ -317,13 +327,17 @@ class APIServer(ABC):
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
"""
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(
self.check_server_status_task(chat_completion=False)
)
if self.config.health_check:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(
self.check_server_status_task(chat_completion=False)
)
else:
self.server_healthy = True
else:
# If health_check is False, always assume healthy
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name

View file

@ -28,7 +28,9 @@ class TrlVllmServer(APIServer):
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
super().__init__(config)
async def check_server_status_task(self, chat_completion: bool = True):
async def check_server_status_task(
self, chat_completion: bool = True, skip_check: bool = False
):
"""
TODO: Implement server health check for trl's vLLM server
"""