diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index fefe414c..fe68e4fb 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -292,6 +292,28 @@ async def test_get_logprobs_normalized_schema(mock_server): assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs +@pytest.mark.asyncio +async def test_get_logprobs_strict_mode_rejects_misaligned_payload(mock_server): + """ManagedServer.get_logprobs fails fast on malformed prompt top-k payload.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + + prompt = "Hello" + prompt_tokens = mock_server.tokenizer.encode(prompt) + + async def _mock_get_logprobs(**kwargs): + return { + "prompt_tokens": prompt_tokens, + "prompt_topk_token_ids": [[tok] for tok in prompt_tokens], + # Missing one row on purpose -> misaligned with prompt length + "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens[:-1]], + } + + mock_server.get_logprobs = _mock_get_logprobs + + with pytest.raises(ValueError, match="align with prompt_tokens length"): + await managed.get_logprobs(prompt=prompt, n=1) + + @pytest.mark.asyncio async def test_reset_clears_sequences(mock_server): """Test that reset() clears all tracked sequences."""