mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
managed_Server pass through and centralize sem logic
This commit is contained in:
parent
c85a3e5ee7
commit
b91922082e
4 changed files with 208 additions and 26 deletions
|
|
@ -421,6 +421,15 @@ class APIServer(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Wrapper for prompt logprobs. Can be overridden by child classes.
|
||||
Returns a dict containing prompt_tokens, prompt_topk_token_ids, prompt_topk_logprobs.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement _get_logprobs_wrapper."
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
|
|
@ -639,6 +648,40 @@ class APIServer(ABC):
|
|||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def _logprobs(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """
    Retry and stat-collection wrapper around _get_logprobs_wrapper (train).

    Args:
        stat_dict: Mutable dict updated in place. It must already hold an
            "attempts" counter; "start" and "end" timestamps are written
            here.
        **kwargs: Passed through unchanged to _get_logprobs_wrapper.

    Returns:
        Whatever _get_logprobs_wrapper returns.
    """
    # Hold the request until the server is flagged healthy, polling once
    # per second.
    while True:
        if self.server_healthy:
            break
        await asyncio.sleep(1)

    async with self.sem:
        # "start" is written only while it is still unset, so a repeated
        # call with the same stat_dict keeps the earliest timestamp.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def _logprobs_eval(self, stat_dict, **kwargs) -> Dict[str, Any]:
    """
    Retry and stat-collection wrapper around _get_logprobs_wrapper (eval).

    Same flow as the train-side wrapper, except that the request runs
    under ``self.eval_sem`` instead of ``self.sem``.

    Args:
        stat_dict: Mutable dict updated in place. It must already hold an
            "attempts" counter; "start" and "end" timestamps are written
            here.
        **kwargs: Passed through unchanged to _get_logprobs_wrapper.

    Returns:
        Whatever _get_logprobs_wrapper returns.
    """
    # Hold the request until the server is flagged healthy, polling once
    # per second.
    while True:
        if self.server_healthy:
            break
        await asyncio.sleep(1)

    async with self.eval_sem:
        # "start" is written only while it is still unset, so a repeated
        # call with the same stat_dict keeps the earliest timestamp.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._get_logprobs_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
|
||||
|
||||
async def get_logprobs(self, **kwargs) -> Dict[str, Any]:
    """
    Prompt-logprob API with strict normalized output schema.

    On the first call this marks the server as initialized. If
    ``config.health_check`` is set and ``config.base_url`` is not None it
    starts the background status-check task; otherwise it flags the server
    healthy directly. The request is then dispatched to ``_logprobs`` when
    ``split == "train"`` and to ``_logprobs_eval`` for any other split, and
    the elapsed time and attempt count are appended to the matching stat
    lists.

    Args:
        **kwargs: Forwarded to the backend wrapper. ``split`` (default
            ``"train"``) is consumed here and not forwarded. ``model`` is
            always overwritten with ``self.config.model_name``.

    Returns:
        The payload produced by the backend wrapper. It is not validated
        here; backends are expected to provide:
        - prompt_tokens
        - prompt_topk_token_ids: List[List[int]]
        - prompt_topk_logprobs: List[List[float]]
    """
    # NOTE: an unconditional ``raise NotImplementedError`` used to sit here,
    # which made everything below unreachable. Backends now plug in through
    # ``_get_logprobs_wrapper`` and this method owns init, dispatch and stats.
    if not self.initialized:
        if self.config.health_check and self.config.base_url is not None:
            # Keep a reference on self so the task is not garbage collected.
            self.check_task = asyncio.create_task(
                self.check_server_status_task(chat_completion=False)
            )
        else:
            # Nothing to probe: treat the server as healthy right away.
            self.server_healthy = True
        self.initialized = True

    kwargs["model"] = self.config.model_name
    split = kwargs.pop("split", "train")
    stat_dict = {"attempts": 0}
    if split == "train":
        payload = await self._logprobs(stat_dict, **kwargs)
        self.request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.attempts_list.append(stat_dict["attempts"])
    else:
        payload = await self._logprobs_eval(stat_dict, **kwargs)
        self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
        self.eval_attempts_list.append(stat_dict["attempts"])
    return payload
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue