init commit

This commit is contained in:
Jai Suphavadeeprasit 2026-03-03 11:32:09 -05:00
parent 887a94374c
commit b9291aa29f
5 changed files with 357 additions and 0 deletions

View file

@ -265,6 +265,34 @@ async def test_bos_token_handling(mock_server):
assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
@pytest.mark.asyncio
async def test_get_logprobs_normalized_schema(mock_server):
"""ManagedServer.get_logprobs returns normalized schema."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
prompt_tokens = mock_server.tokenizer.encode(prompt)
output_tokens = [[ord("!"), ord("?")]]
output_logprobs = [[-0.1, -0.2]]
mock_server.set_tokens_and_logprobs_response(
prompt=prompt,
prompt_tokens=prompt_tokens,
output_tokens_list=output_tokens,
output_logprobs_list=output_logprobs,
finish_reasons=["stop"],
)
payload = await managed.get_logprobs(prompt=prompt, n=1)
assert payload["prompt_tokens"] == prompt_tokens
assert payload["sequence_token_ids"] == output_tokens
assert payload["sequence_logprobs"] == output_logprobs
assert payload["finish_reasons"] == ["stop"]
assert payload["sequence_topk_token_ids"] == [[[ord("!")], [ord("?")]]]
assert payload["sequence_topk_logprobs"] == [[[-0.1], [-0.2]]]
@pytest.mark.asyncio
async def test_reset_clears_sequences(mock_server):
"""Test that reset() clears all tracked sequences."""