mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tests
This commit is contained in:
parent
8f304d44fd
commit
51088ac24d
1 changed files with 22 additions and 0 deletions
|
|
@ -292,6 +292,28 @@ async def test_get_logprobs_normalized_schema(mock_server):
|
||||||
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
|
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_logprobs_strict_mode_rejects_misaligned_payload(mock_server):
|
||||||
|
"""ManagedServer.get_logprobs fails fast on malformed prompt top-k payload."""
|
||||||
|
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
|
||||||
|
|
||||||
|
prompt = "Hello"
|
||||||
|
prompt_tokens = mock_server.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
async def _mock_get_logprobs(**kwargs):
|
||||||
|
return {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"prompt_topk_token_ids": [[tok] for tok in prompt_tokens],
|
||||||
|
# Missing one row on purpose -> misaligned with prompt length
|
||||||
|
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens[:-1]],
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_server.get_logprobs = _mock_get_logprobs
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="align with prompt_tokens length"):
|
||||||
|
await managed.get_logprobs(prompt=prompt, n=1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reset_clears_sequences(mock_server):
|
async def test_reset_clears_sequences(mock_server):
|
||||||
"""Test that reset() clears all tracked sequences."""
|
"""Test that reset() clears all tracked sequences."""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue