prompt logprobs

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 21:56:11 -05:00
parent e98100e5f6
commit 439b9b129b
7 changed files with 73 additions and 138 deletions

View file

@ -267,7 +267,7 @@ async def test_bos_token_handling(mock_server):
@pytest.mark.asyncio
async def test_get_logprobs_normalized_schema(mock_server):
"""ManagedServer.get_logprobs returns normalized schema."""
"""ManagedServer.get_logprobs returns normalized prompt schema."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
@ -286,11 +286,8 @@ async def test_get_logprobs_normalized_schema(mock_server):
payload = await managed.get_logprobs(prompt=prompt, n=1)
assert payload["prompt_tokens"] == prompt_tokens
assert payload["sequence_token_ids"] == output_tokens
assert payload["sequence_logprobs"] == output_logprobs
assert payload["finish_reasons"] == ["stop"]
assert payload["sequence_topk_token_ids"] == [[[ord("!")], [ord("?")]]]
assert payload["sequence_topk_logprobs"] == [[[-0.1], [-0.2]]]
assert payload["prompt_topk_token_ids"] == [[tok] for tok in prompt_tokens]
assert payload["prompt_topk_logprobs"] == [[1.0] for _ in prompt_tokens]
@pytest.mark.asyncio