thanks cursor

This commit is contained in:
Shannon Sands 2025-05-17 13:00:35 -07:00
parent 59b08a1aa9
commit 90138376f9

View file

@@ -6,7 +6,6 @@ Developed with much help from @winglian when they worked on integrating Atropos
import time
import uuid
import asyncio
import aiohttp
from openai.types.chat.chat_completion import (
@@ -29,39 +28,11 @@ class TrlVllmServer(APIServer):
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
super().__init__(config)
async def check_server_status_task(self, _ = True):
async def check_server_status_task(self, chat_completion: bool = True):
"""
Perform a health check for the TRL VLLM server by sending a minimal request to /generate.
TODO: Implement server health check for trl's vLLM server
"""
health_check_url = f"{self.config.base_url}/generate/"
# Minimal payload that the /generate endpoint would accept without erroring.
# This might need adjustment based on the TRL VLLM server's specific requirements.
minimal_payload = {
"prompts": ["test"], # Using a non-empty prompt as some servers might require it
"max_tokens": 1,
"n": 1
}
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.post(health_check_url, json=minimal_payload, timeout=10) as response: # Added timeout
if response.status == 200:
# Optionally, further check response content if a specific "healthy" body is expected
# For now, status 200 is considered healthy.
self.server_healthy = True
else:
# Log warning for non-200 status if a logger is available/configured
# logger.warning(f"TRL VLLM server health check failed: Status {response.status} for {health_check_url}")
self.server_healthy = False
except aiohttp.ClientError as e:
# Log error for connection issues if a logger is available/configured
# logger.error(f"TRL VLLM server health check connection error: {e}")
self.server_healthy = False
except Exception as e: # Catch any other unexpected errors during the check
# Log error for unexpected issues if a logger is available/configured
# logger.error(f"Unexpected error during TRL VLLM server health check: {e}")
self.server_healthy = False
await asyncio.sleep(10) # Check periodically (e.g., every 10 seconds)
self.server_healthy = True
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""