mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge conflict commit
This commit is contained in:
commit
f198c1738e
13 changed files with 579 additions and 14 deletions
|
|
@ -268,6 +268,91 @@ async def test_bos_token_handling(mock_server):
|
|||
assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_logprobs_normalized_schema(mock_server):
    """ManagedServer.get_logprobs returns the normalized prompt schema.

    The backend mock returns the three prompt-level fields; the managed
    wrapper must pass them through unchanged under the same keys.
    """
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    prompt = "Hello"
    prompt_tokens = mock_server.tokenizer.encode(prompt)
    # Two top-k candidates per prompt position, with matching logprob lists.
    prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
    prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]

    async def _mock_get_logprobs(**kwargs):
        # The wrapper must forward the raw prompt string to the backend.
        assert kwargs.get("prompt") == prompt
        return {
            "prompt_tokens": prompt_tokens,
            "prompt_topk_token_ids": prompt_topk_token_ids,
            "prompt_topk_logprobs": prompt_topk_logprobs,
        }

    mock_server.get_logprobs = _mock_get_logprobs

    payload = await managed.get_logprobs(prompt=prompt, n=1)

    # All three normalized fields are passed through verbatim.
    assert payload["prompt_tokens"] == prompt_tokens
    assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
    assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_logprobs_messages_passthrough(mock_server):
    """ManagedServer.get_logprobs converts messages and passes prompt through.

    When called with chat ``messages`` instead of a raw ``prompt``, the
    wrapper must render the messages via ``_convert_messages_to_prompt`` and
    hand the resulting string to the backend.
    """
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
    messages = [{"role": "user", "content": "Hello"}]
    # The wrapper's own conversion defines the prompt the backend should see.
    expected_prompt = managed._convert_messages_to_prompt(messages)
    prompt_tokens = mock_server.tokenizer.encode(expected_prompt)

    async def _mock_get_logprobs(**kwargs):
        assert kwargs.get("prompt") == expected_prompt
        return {
            "prompt_tokens": prompt_tokens,
            "prompt_topk_token_ids": [[t] for t in prompt_tokens],
            "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
        }

    mock_server.get_logprobs = _mock_get_logprobs
    payload = await managed.get_logprobs(messages=messages, top_k=1)

    assert payload["prompt_tokens"] == prompt_tokens
    # One top-k entry per prompt position (top_k=1).
    assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens)
    assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_logprobs_input_ids_only_passthrough(mock_server):
    """ManagedServer.get_logprobs supports input_ids-only without requiring prompt.

    Pre-tokenized ``input_ids`` must be forwarded to the backend unchanged,
    with ``prompt`` left as ``None``.
    """
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
    input_ids = [10, 20, 30]

    async def _mock_get_logprobs(**kwargs):
        # input_ids are forwarded verbatim; no prompt string is synthesized.
        assert "input_ids" in kwargs
        assert kwargs["input_ids"] == input_ids
        assert kwargs.get("prompt") is None
        return {
            "prompt_tokens": input_ids,
            "prompt_topk_token_ids": [[t] for t in input_ids],
            "prompt_topk_logprobs": [[-0.1] for _ in input_ids],
        }

    mock_server.get_logprobs = _mock_get_logprobs
    payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)

    assert payload["prompt_tokens"] == input_ids
    assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
    assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
    """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.

    With no ``get_logprobs`` patched onto the backend, the wrapper must raise
    ``NotImplementedError`` rather than silently falling back.
    """
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    prompt = "Hello"
    with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
        await managed.get_logprobs(prompt=prompt, n=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_sequences(mock_server):
|
||||
"""Test that reset() clears all tracked sequences."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue