import asyncio
from typing import Dict, List, Literal, Union

from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)
from openai.types.completion import Completion, CompletionChoice

# Shared alias for the finish-reason literal used by both helper constructors.
FinishReason = Literal[
    "stop", "length", "tool_calls", "content_filter", "function_call"
]


def create_chat_completion(
    resp: Union[str, List[str]],
    n: int = 1,
    finish_reason: Union[FinishReason, List[FinishReason]] = "stop",
) -> ChatCompletion:
    """Build a minimal ChatCompletion test fixture.

    :param resp: a single assistant message, or one message per choice when
        ``n > 1`` (indexed by choice position).
    :param n: number of choices to generate.
    :param finish_reason: one finish reason applied to all choices, or a list
        with one reason per choice.
    :return: a ChatCompletion with placeholder id/model metadata.
    """
    choices = [
        Choice(
            finish_reason=(
                finish_reason if isinstance(finish_reason, str) else finish_reason[i]
            ),
            index=i,
            message=ChatCompletionMessage(
                content=resp if isinstance(resp, str) else resp[i],
                role="assistant",
            ),
        )
        for i in range(n)
    ]
    return ChatCompletion(
        id="test_id",
        created=0,
        model="test_model",
        object="chat.completion",
        choices=choices,
    )


def create_completion(
    resp: Union[str, List[str]],
    n: int = 1,
    finish_reason: Union[FinishReason, List[FinishReason]] = "stop",
) -> Completion:
    """Build a minimal (legacy text) Completion test fixture.

    :param resp: a single completion text, or one text per choice when
        ``n > 1`` (indexed by choice position).
    :param n: number of choices to generate.
    :param finish_reason: one finish reason applied to all choices, or a list
        with one reason per choice.
    :return: a Completion with placeholder id/model metadata.
    """
    choices = [
        CompletionChoice(
            finish_reason=(
                finish_reason if isinstance(finish_reason, str) else finish_reason[i]
            ),
            index=i,
            text=resp if isinstance(resp, str) else resp[i],
        )
        for i in range(n)
    ]
    return Completion(
        id="test_id",
        created=0,
        model="test_model",
        object="text_completion",
        choices=choices,
    )


class ServerHarness:
    """In-memory stand-in for an inference server.

    Canned responses are registered up front (keyed by the prompt or by a
    flattened form of the chat messages) and replayed by the async lookup
    methods, so tests can exercise client code without a live server.
    """

    def __init__(self):
        # Maps request key -> canned ChatCompletion / Completion.
        self.response_map: Dict[str, Union[ChatCompletion, Completion]] = {}
        # Maps prompt -> canned (prompt_tokens, output_tokens, logprobs,
        # finish_reasons) tuple for the tokens/logprobs endpoint.
        self.tokens_and_logprobs_map: Dict[str, tuple] = {}
        self.sem = asyncio.Semaphore(1)
        self.eval_sem = asyncio.Semaphore(1)

    def conv_to_dictkey(self, input_message: List[Dict[str, str]]) -> str:
        """Flatten a chat message list into a deterministic string key.

        Each message contributes a ``role:...`` line followed by a
        ``content:...`` line, joined by newlines.
        """
        parts: List[str] = []
        for item in input_message:
            parts.append(f"role:{item['role']}")
            parts.append(f"content:{item['content']}")
        return "\n".join(parts)

    async def update_weight(self, weight):
        """No-op: the harness has no model weights to update."""
        pass

    def set_desired_response(
        self, input_message: List[Dict[str, str]], desired_response: ChatCompletion
    ):
        """Register the ChatCompletion to return for a given message list."""
        dictkey = self.conv_to_dictkey(input_message)
        self.response_map[dictkey] = desired_response

    def set_desired_completion(self, input_message: str, completion: Completion):
        """Register the Completion to return for a given prompt string."""
        self.response_map[input_message] = completion

    def set_tokens_and_logprobs_response(
        self,
        prompt: str,
        prompt_tokens: list,
        output_tokens_list: list,
        output_logprobs_list: list,
        finish_reasons: list,
    ):
        """
        Set expected response for _tokens_and_logprobs_completion_wrapper.

        Args:
            prompt: The prompt string (key)
            prompt_tokens: List of prompt token IDs
            output_tokens_list: List of lists of output token IDs (one per completion)
            output_logprobs_list: List of lists of output logprobs (one per completion)
            finish_reasons: List of finish reasons (one per completion)
        """
        self.tokens_and_logprobs_map[prompt] = (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons,
        )

    async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
        """Return the canned ChatCompletion registered for ``messages``.

        Raises KeyError (with the flattened key in the message) when no
        response was registered for these messages.
        """
        messages = kwargs.get("messages")
        dictkey = self.conv_to_dictkey(messages)
        # NOTE: subscript lookup (not .get) so a missing key actually raises
        # KeyError instead of silently returning None.
        try:
            return self.response_map[dictkey]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for key:\n{dictkey}") from e

    async def completion(self, *args, **kwargs) -> Completion:
        """Return the canned Completion registered for ``prompt``.

        Raises KeyError when no completion was registered for this prompt.
        """
        prompt = kwargs.get("prompt")
        try:
            return self.response_map[prompt]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for key:\n{prompt}") from e

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Mock implementation of tokens and logprobs completion wrapper.

        Returns:
            Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list,
            finish_reasons)

        Raises:
            KeyError: if no response was registered for this prompt.
        """
        prompt = kwargs.get("prompt")
        try:
            return self.tokens_and_logprobs_map[prompt]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for prompt:\n{prompt}") from e


if __name__ == "__main__":

    async def main():
        test_compl = create_chat_completion("hello")
        harness = ServerHarness()
        harness.set_desired_response([{"role": "user", "content": "hi"}], test_compl)
        print(harness.response_map)
        print(harness.conv_to_dictkey([{"role": "user", "content": "hi"}]))
        print(
            await harness.chat_completion(messages=[{"role": "user", "content": "hi"}])
        )
        # now, let's test the completion
        test_completion = create_completion("\nhello")
        harness.set_desired_completion("hi", test_completion)
        print(harness.response_map)
        print(await harness.completion(prompt="hi"))

    asyncio.run(main())