mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Add SGLang-specific token-level logprob handling, plus server manager/baseline logprob/token functions
This commit is contained in:
parent
4862e9972f
commit
c36ec29656
4 changed files with 512 additions and 37 deletions
|
|
@ -108,7 +108,7 @@ class ServerBaseline(BaseModel):
|
|||
rolling_buffer_length: int = Field(
|
||||
default=1000, description="Length of the rolling buffer to store metrics."
|
||||
)
|
||||
server_type: Literal["openai", "trl"] = Field(
|
||||
server_type: Literal["openai", "trl", "sglang"] = Field(
|
||||
default="openai", description="Type of server to use, openai or trl"
|
||||
)
|
||||
|
||||
|
|
@ -217,6 +217,16 @@ class APIServer(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
async def _tokens_and_logprobs_completion_wrapper(
    self, **kwargs
) -> tuple[list, list, list, list]:
    """
    Backend-specific tokens-and-logprobs completion call.

    This is an abstract hook: the child class supplies the real
    implementation. All keyword arguments are passed through untouched.

    Returns:
        A 4-tuple of lists, in this order:
        (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
    """
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
)
|
||||
|
|
@ -352,3 +362,77 @@ class APIServer(ABC):
|
|||
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
|
||||
self.eval_attempts_list.append(stat_dict["attempts"])
|
||||
return ret_data
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def _tokens_and_logprobs_comp(
    self, stat_dict, **kwargs
) -> tuple[list, list, list, list]:
    """
    Run one training-side tokens/logprobs request and record its stats.

    The decorator re-invokes this coroutine on failure (stop condition:
    3 attempts; randomized exponential wait, multiplier 1, max 10).

    Args:
        stat_dict: Mutable dict shared across retries. "start" is written
            only if it is missing or None, "attempts" is incremented on
            every invocation, and "end" is written after a successful call.
        **kwargs: Forwarded unchanged to
            ``_tokens_and_logprobs_completion_wrapper``.

    Returns:
        Whatever the wrapper returns, i.e. the tuple
        (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
    """
    # Block here, polling once per second, until the health flag is set.
    while True:
        if self.server_healthy:
            break
        await asyncio.sleep(1)

    # The training semaphore bounds how many requests are in flight.
    async with self.sem:
        # Keep the first start time so retries measure total elapsed time.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def _tokens_and_logprobs_comp_eval(
    self, stat_dict, **kwargs
) -> tuple[list, list, list, list]:
    """
    Run one eval-side tokens/logprobs request and record its stats.

    Same flow as the training variant, except that concurrency is bounded
    by ``self.eval_sem`` instead of ``self.sem``. The decorator re-invokes
    this coroutine on failure (stop condition: 3 attempts; randomized
    exponential wait, multiplier 1, max 10).

    Args:
        stat_dict: Mutable dict shared across retries. "start" is written
            only if it is missing or None, "attempts" is incremented on
            every invocation, and "end" is written after a successful call.
        **kwargs: Forwarded unchanged to
            ``_tokens_and_logprobs_completion_wrapper``.

    Returns:
        Whatever the wrapper returns, i.e. the tuple
        (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
    """
    # Block here, polling once per second, until the health flag is set.
    while True:
        if self.server_healthy:
            break
        await asyncio.sleep(1)

    # Eval traffic has its own semaphore, separate from training requests.
    async with self.eval_sem:
        # Keep the first start time so retries measure total elapsed time.
        if stat_dict.get("start") is None:
            stat_dict["start"] = time.time()
        stat_dict["attempts"] += 1
        result = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
        stat_dict["end"] = time.time()
    return result
|
||||
|
||||
async def tokens_and_logprobs_completion(
    self, **kwargs
) -> tuple[list, list, list, list]:
    """
    Public entry point for a tokens-and-logprobs completion.

    On the first call this performs one-time setup: when health checking is
    enabled and a ``base_url`` is configured, a background status-check task
    is started; in every other case the server is marked healthy right away.
    The request is then dispatched to the train or eval worker and its
    duration and attempt count are appended to the matching metric lists.

    Args:
        **kwargs: Request arguments. "model" is always overwritten with
            ``self.config.model_name``. An optional "split" key (default
            "train") selects the worker and is removed before forwarding;
            any value other than "train" uses the eval path.

    Returns:
        A tuple of (prompt_tokens, output_tokens, output_logprobs,
        finish_reasons).
    """
    if not self.initialized:
        # Only probe when asked to AND there is a custom endpoint to probe;
        # with no base_url the health check is skipped (OpenAI API case).
        wants_probe = self.config.health_check and self.config.base_url is not None
        if wants_probe:
            self.check_task = asyncio.create_task(
                self.check_server_status_task(chat_completion=False)
            )
        else:
            self.server_healthy = True
        self.initialized = True

    kwargs["model"] = self.config.model_name
    split = kwargs.pop("split", "train")
    stat_dict = {"attempts": 0}

    if split != "train":
        # Eval requests go through their own worker so they are not queued
        # behind training traffic.
        result = await self._tokens_and_logprobs_comp_eval(stat_dict, **kwargs)
        elapsed = stat_dict["end"] - stat_dict["start"]
        self.eval_request_timings.append(elapsed)
        self.eval_attempts_list.append(stat_dict["attempts"])
        return result

    result = await self._tokens_and_logprobs_comp(stat_dict, **kwargs)
    elapsed = stat_dict["end"] - stat_dict["start"]
    self.request_timings.append(elapsed)
    self.attempts_list.append(stat_dict["attempts"])
    return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue