add sglang specific token level logprob handling and server manager/baseline logprob/token fn

This commit is contained in:
Dakota 2025-10-16 12:38:03 -05:00
parent 4862e9972f
commit c36ec29656
4 changed files with 512 additions and 37 deletions

View file

@ -108,7 +108,7 @@ class ServerBaseline(BaseModel):
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
server_type: Literal["openai", "trl"] = Field(
server_type: Literal["openai", "trl", "sglang"] = Field(
default="openai", description="Type of server to use, openai or trl"
)
@ -217,6 +217,16 @@ class APIServer(ABC):
"""
pass
@abstractmethod
async def _tokens_and_logprobs_completion_wrapper(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Wrapper for tokens and logprobs completion. Should be overridden by the child class.
Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
"""
pass
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@ -352,3 +362,77 @@ class APIServer(ABC):
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _tokens_and_logprobs_comp(
self, stat_dict, **kwargs
) -> tuple[list, list, list, list]:
"""
Simple retry and stat collection wrapper for tokens and logprobs completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _tokens_and_logprobs_comp_eval(
self, stat_dict, **kwargs
) -> tuple[list, list, list, list]:
"""
Simple retry and stat collection wrapper for tokens and logprobs completion.
"""
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
stat_dict["end"] = time.time()
return completions
async def tokens_and_logprobs_completion(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Tokens and logprobs completion handler, waits for the server to be healthy and then calls the wrapper.
Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
"""
if not self.initialized:
if self.config.health_check:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(
self.check_server_status_task(chat_completion=False)
)
else:
self.server_healthy = True
else:
# If health_check is False, always assume healthy
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._tokens_and_logprobs_comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data