managed_Server: pass through prompt-logprob calls and centralize semaphore logic

This commit is contained in:
Jai Suphavadeeprasit 2026-03-05 15:46:33 -05:00
parent c85a3e5ee7
commit b91922082e
4 changed files with 208 additions and 26 deletions

View file

@ -421,6 +421,15 @@ class APIServer(ABC):
"""
pass
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
"""
Wrapper for prompt logprobs. Can be overridden by child classes.
Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
)
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@ -639,6 +648,40 @@ class APIServer(ABC):
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
    stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """
    Retry and stat-collection wrapper around _get_logprobs_wrapper,
    gated by the training-split semaphore.
    """
    # Spin until the health checker marks the server usable.
    while not self.server_healthy:
        await asyncio.sleep(1)
    async with self.sem:
        # First attempt stamps the start time; retries keep the original stamp.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
@retry(
    stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """
    Retry and stat-collection wrapper around _get_logprobs_wrapper,
    gated by the eval-split semaphore.
    """
    # Spin until the health checker marks the server usable.
    while not self.server_healthy:
        await asyncio.sleep(1)
    async with self.eval_sem:
        # First attempt stamps the start time; retries keep the original stamp.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Prompt-logprob API with strict normalized output schema.

    Returns a dict with:
      - prompt_tokens: List[int]
      - prompt_topk_token_ids: List[List[int]]
      - prompt_topk_logprobs: List[List[float]]

    Raises whatever the backend's _get_logprobs_wrapper raises after the
    retry policy in _logprobs / _logprobs_eval is exhausted.
    """
    # NOTE(review): the diff residue left the old `raise NotImplementedError`
    # above the new body, making everything after it unreachable — removed.
    if not self.initialized:
        # Lazy one-time init: spawn the health-check task only when a remote
        # base_url is configured AND health checking is enabled; otherwise
        # assume the server is healthy so the wrappers don't spin forever.
        if self.config.health_check and self.config.base_url is not None:
            self.check_task = asyncio.create_task(
                self.check_server_status_task(chat_completion=False)
            )
        else:
            self.server_healthy = True
        self.initialized = True
    kwargs["model"] = self.config.model_name
    # "split" selects which semaphore/stat buckets to use; default is train.
    split = kwargs.pop("split", "train")
    stat_dict = {"attempts": 0}
    if split == "train":
        payload = await self._logprobs(stat_dict, **kwargs)
        self.request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.attempts_list.append(stat_dict["attempts"])
    else:
        payload = await self._logprobs_eval(stat_dict, **kwargs)
        self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.eval_attempts_list.append(stat_dict["attempts"])
    return payload