Add a managed server to make grabbing logprobs easier with tokenized items

This commit is contained in:
dmahan93 2025-10-24 13:09:46 -07:00
parent 312f8859e3
commit 7bf4cfbf80
6 changed files with 1138 additions and 29 deletions

View file

@@ -87,6 +87,7 @@ def create_completion(
class ServerHarness:
def __init__(self):
    """Create the empty response registries and two single-permit semaphores."""
    # Two independent semaphores, each initialised with exactly one permit.
    self.sem = asyncio.Semaphore(1)
    self.eval_sem = asyncio.Semaphore(1)
    # Canned responses, filled in later through the set_* methods.
    self.response_map = {}
    self.tokens_and_logprobs_map = {}  # prompt -> tokens/logprobs response tuple
@@ -110,6 +111,31 @@ class ServerHarness:
def set_desired_completion(self, input_message: str, completion: Completion):
    """
    Register the completion stored under the given input message.

    Args:
        input_message: Key under which the completion is stored in ``response_map``
        completion: Completion object to store; replaces any earlier entry for the same key
    """
    self.response_map.update({input_message: completion})
def set_tokens_and_logprobs_response(
    self,
    prompt: str,
    prompt_tokens: list,
    output_tokens_list: list,
    output_logprobs_list: list,
    finish_reasons: list,
):
    """
    Register the canned response that _tokens_and_logprobs_completion_wrapper
    looks up for a prompt.

    Args:
        prompt: The prompt string, used as the lookup key
        prompt_tokens: List of prompt token IDs
        output_tokens_list: List of lists of output token IDs (one per completion)
        output_logprobs_list: List of lists of output logprobs (one per completion)
        finish_reasons: List of finish reasons (one per completion)
    """
    # Stored as one 4-tuple in this fixed order; a later call with the same
    # prompt overwrites the earlier entry.
    canned_response = (
        prompt_tokens,
        output_tokens_list,
        output_logprobs_list,
        finish_reasons,
    )
    self.tokens_and_logprobs_map[prompt] = canned_response
async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
messages = kwargs.get("messages")
dictkey = self.conv_to_dictkey(messages)
@@ -125,6 +151,21 @@ class ServerHarness:
except KeyError as e:
raise KeyError(f"KeyError: {e} for key:\n{prompt}")
async def _tokens_and_logprobs_completion_wrapper(
    self, **kwargs
) -> tuple[list, list, list, list]:
    """
    Mock implementation of tokens and logprobs completion wrapper.

    Looks up the response registered in ``self.tokens_and_logprobs_map`` for
    the ``prompt`` keyword argument.

    Args:
        **kwargs: Keyword arguments of the call; only ``prompt`` is read.

    Returns:
        Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)

    Raises:
        KeyError: If no response is registered for the prompt (this includes
            a call that passes no ``prompt`` at all, which looks up ``None``).
    """
    prompt = kwargs.get("prompt")
    try:
        # Index directly instead of dict.get(): get() returns None for an
        # unknown key, so the handler below would never run and the caller
        # would receive None instead of the documented 4-tuple.
        return self.tokens_and_logprobs_map[prompt]
    except KeyError as e:
        raise KeyError(f"KeyError: {e} for prompt:\n{prompt}") from e
if __name__ == "__main__":