prompt logprobs simplicity

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 22:06:49 -05:00
parent f1c20591b6
commit 5aaf7a346c
5 changed files with 69 additions and 69 deletions

View file

@ -272,22 +272,24 @@ async def test_get_logprobs_normalized_schema(mock_server):
prompt = "Hello"
prompt_tokens = mock_server.tokenizer.encode(prompt)
output_tokens = [[ord("!"), ord("?")]]
output_logprobs = [[-0.1, -0.2]]
prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]
mock_server.set_tokens_and_logprobs_response(
prompt=prompt,
prompt_tokens=prompt_tokens,
output_tokens_list=output_tokens,
output_logprobs_list=output_logprobs,
finish_reasons=["stop"],
)
async def _mock_get_logprobs(**kwargs):
assert kwargs.get("prompt") == prompt
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": prompt_topk_token_ids,
"prompt_topk_logprobs": prompt_topk_logprobs,
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(prompt=prompt, n=1)
assert payload["prompt_tokens"] == prompt_tokens
assert payload["prompt_topk_token_ids"] == [[tok] for tok in prompt_tokens]
assert payload["prompt_topk_logprobs"] == [[1.0] for _ in prompt_tokens]
assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
@pytest.mark.asyncio