managed_server: pass through kwargs and centralize semaphore logic

This commit is contained in:
Jai Suphavadeeprasit 2026-03-05 15:46:33 -05:00
parent c85a3e5ee7
commit b91922082e
4 changed files with 208 additions and 26 deletions

View file

@ -421,6 +421,15 @@ class APIServer(ABC):
"""
pass
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Wrapper for prompt logprobs. Can be overridden by child classes.
Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
)
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@ -639,6 +648,40 @@ class APIServer(ABC):
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
    stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """Run one training-split logprob request under the shared semaphore.

    Waits until the server reports healthy, records the first start time,
    bumps the attempt counter in ``stat_dict``, then delegates to the
    backend wrapper and stamps the end time before returning its payload.
    """
    while not self.server_healthy:
        await asyncio.sleep(1)
    async with self.sem:
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
        return result
@retry(
    stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """Run one eval-split logprob request under the eval semaphore.

    Mirrors ``_logprobs`` but gates concurrency on ``self.eval_sem``:
    waits for server health, tracks start/end timestamps and attempts in
    ``stat_dict``, and returns the backend wrapper's payload.
    """
    while not self.server_healthy:
        await asyncio.sleep(1)
    async with self.eval_sem:
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
        return result
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Prompt-logprob API with strict normalized output schema.

    Returns a dict containing:
      - prompt_tokens: List[int]
      - prompt_topk_token_ids: List[List[int]]
      - prompt_topk_logprobs: List[List[float]]

    Raises whatever the backend's _get_logprobs_wrapper raises after the
    retry wrappers are exhausted.
    """
    # BUG FIX: an unconditional ``raise NotImplementedError`` (leftover from
    # the abstract version of this method) preceded the body below, making
    # everything after it unreachable. The raise is removed; backends now
    # implement _get_logprobs_wrapper instead of overriding this method.
    if not self.initialized:
        # Lazily start the health-check loop on first use; when health
        # checking is disabled or no base_url exists, assume healthy so the
        # wait loops in _logprobs/_logprobs_eval do not block forever.
        if self.config.health_check:
            if self.config.base_url is not None:
                self.check_task = asyncio.create_task(
                    self.check_server_status_task(chat_completion=False)
                )
            else:
                self.server_healthy = True
        else:
            self.server_healthy = True
        self.initialized = True
    kwargs["model"] = self.config.model_name
    # "split" selects which semaphore and which stat buckets to use; it is
    # popped so the backend wrapper never sees it.
    split = kwargs.pop("split", "train")
    stat_dict = {"attempts": 0}
    if split == "train":
        payload = await self._logprobs(stat_dict, **kwargs)
        self.request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.attempts_list.append(stat_dict["attempts"])
    else:
        payload = await self._logprobs_eval(stat_dict, **kwargs)
        self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.eval_attempts_list.append(stat_dict["attempts"])
    return payload

View file

@ -260,7 +260,7 @@ class VLLMServer(APIServer):
return [], []
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Fetch normalized prompt logprobs from vLLM /generate with optional top-k.
@ -315,26 +315,19 @@ class VLLMServer(APIServer):
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
# Keep semaphore behavior consistent with other server calls.
split = request_data.pop("split", "train")
sem = self.sem if split == "train" else self.eval_sem
while not self.server_healthy:
await asyncio.sleep(1)
async with sem:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
response.raise_for_status()
results = await response.json()
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=request_data,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
) as response:
response.raise_for_status()
results = await response.json()
raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None: