mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
add managed server to make grabbing logprobs easier w/ tokenized items
This commit is contained in:
parent
312f8859e3
commit
7bf4cfbf80
6 changed files with 1138 additions and 29 deletions
|
|
@ -87,6 +87,7 @@ def create_completion(
|
|||
class ServerHarness:
|
||||
def __init__(self):
|
||||
self.response_map = dict()
|
||||
self.tokens_and_logprobs_map = dict() # Map for tokens/logprobs responses
|
||||
self.sem = asyncio.Semaphore(1)
|
||||
self.eval_sem = asyncio.Semaphore(1)
|
||||
pass
|
||||
|
|
@ -110,6 +111,31 @@ class ServerHarness:
|
|||
def set_desired_completion(self, input_message: str, completion: Completion):
|
||||
self.response_map[input_message] = completion
|
||||
|
||||
def set_tokens_and_logprobs_response(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_tokens: list,
|
||||
output_tokens_list: list,
|
||||
output_logprobs_list: list,
|
||||
finish_reasons: list,
|
||||
):
|
||||
"""
|
||||
Set expected response for _tokens_and_logprobs_completion_wrapper.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string (key)
|
||||
prompt_tokens: List of prompt token IDs
|
||||
output_tokens_list: List of lists of output token IDs (one per completion)
|
||||
output_logprobs_list: List of lists of output logprobs (one per completion)
|
||||
finish_reasons: List of finish reasons (one per completion)
|
||||
"""
|
||||
self.tokens_and_logprobs_map[prompt] = (
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons,
|
||||
)
|
||||
|
||||
async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
|
||||
messages = kwargs.get("messages")
|
||||
dictkey = self.conv_to_dictkey(messages)
|
||||
|
|
@ -125,6 +151,21 @@ class ServerHarness:
|
|||
except KeyError as e:
|
||||
raise KeyError(f"KeyError: {e} for key:\n{prompt}")
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(
|
||||
self, **kwargs
|
||||
) -> tuple[list, list, list, list]:
|
||||
"""
|
||||
Mock implementation of tokens and logprobs completion wrapper.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
|
||||
"""
|
||||
prompt = kwargs.get("prompt")
|
||||
try:
|
||||
return self.tokens_and_logprobs_map.get(prompt)
|
||||
except KeyError as e:
|
||||
raise KeyError(f"KeyError: {e} for prompt:\n{prompt}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue