ManagedServer pass-through and centralize sem logic

This commit is contained in:
Jai Suphavadeeprasit 2026-03-05 15:46:33 -05:00
parent c85a3e5ee7
commit b91922082e
4 changed files with 208 additions and 26 deletions

View file

@ -292,6 +292,30 @@ async def test_get_logprobs_normalized_schema(mock_server):
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
@pytest.mark.asyncio
async def test_get_logprobs_messages_passthrough(mock_server):
    """Check that get_logprobs(messages=...) hands the converted prompt to the backend.

    The backend's ``get_logprobs`` is swapped for a stub that asserts the
    ``prompt`` keyword equals the output of ``_convert_messages_to_prompt``
    and returns one top-k entry per prompt token. The test then checks that
    the payload returned by ``ManagedServer.get_logprobs`` carries the same
    tokens and top-k lists of matching length.
    """
    server = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
    chat = [{"role": "user", "content": "Hello"}]
    rendered_prompt = server._convert_messages_to_prompt(chat)
    token_ids = mock_server.tokenizer.encode(rendered_prompt)

    async def _stub_backend_logprobs(**kwargs):
        # The wrapper must forward the prompt it rendered from the messages.
        assert kwargs.get("prompt") == rendered_prompt
        return {
            "prompt_tokens": token_ids,
            "prompt_topk_token_ids": [[tok] for tok in token_ids],
            "prompt_topk_logprobs": [[-0.1] for _ in token_ids],
        }

    # Replace the backend call so the wrapper talks to the stub.
    mock_server.get_logprobs = _stub_backend_logprobs

    result = await server.get_logprobs(messages=chat, top_k=1)

    assert result["prompt_tokens"] == token_ids
    # Each top-k list must have one entry per prompt token.
    for field in ("prompt_topk_token_ids", "prompt_topk_logprobs"):
        assert len(result[field]) == len(token_ids)
@pytest.mark.asyncio
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""