Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in: parent 887a94374c, commit add42a2afb
11 changed files with 3370 additions and 34 deletions
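
The proxy added here exposes session-scoped, OpenAI-style routes. A minimal client sketch, inferred from the tests in this commit (the host/port and the plain-requests usage are illustrative, not part of the diff):

    import requests

    BASE = "http://localhost:8000"  # placeholder; wherever create_app() is served

    # Each session gets its own UUID; nodes and tool parsing are scoped to it.
    uuid = requests.post(f"{BASE}/sessions/create", json={}).json()["uuid"]

    # Standard OpenAI-style chat completion, routed through the session.
    resp = requests.post(
        f"{BASE}/{uuid}/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 100},
    )
    print(resp.json()["choices"][0]["message"]["content"])

    # Inspect the accumulated token/logprob nodes, then clean up.
    nodes = requests.get(f"{BASE}/{uuid}/nodes").json()["nodes"]
    requests.delete(f"{BASE}/{uuid}")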

@@ -5,19 +5,32 @@ def pytest_addoption(parser):
     parser.addoption(
         "--runproviders", action="store_true", default=False, help="run provider tests"
     )
+    parser.addoption(
+        "--run-gpu",
+        action="store_true",
+        default=False,
+        help="run GPU integration tests",
+    )
 
 
 def pytest_configure(config):
     config.addinivalue_line(
         "markers", "providers: mark test as requires providers api keys to run"
     )
+    config.addinivalue_line(
+        "markers", "gpu: mark test as requiring GPU (skipped unless --run-gpu)"
+    )
 
 
 def pytest_collection_modifyitems(config, items):
-    if config.getoption("--runproviders"):
-        # --runproviders given in cli: do not skip slow tests
-        return
-    skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
-    for item in items:
-        if "providers" in item.keywords:
-            item.add_marker(skip_providers)
+    if not config.getoption("--runproviders"):
+        skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
+        for item in items:
+            if "providers" in item.keywords:
+                item.add_marker(skip_providers)
+
+    if not config.getoption("--run-gpu"):
+        skip_gpu = pytest.mark.skip(reason="need --run-gpu option to run")
+        for item in items:
+            if "gpu" in item.keywords:
+                item.add_marker(skip_gpu)
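
With these hooks in place, both suites stay opt-in; for example (the second command is quoted from the integration test's docstring later in this commit):

    pytest --runproviders          # enable provider-marked tests
    pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s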

@@ -493,6 +493,295 @@ async def test_multi_turn_chat_with_branching(mock_server):
        assert f"More{actual_i}" in node.full_text  # Has third turn


# ---------------------------------------------------------------------------
# Tool call support in ManagedServer.chat_completion()
# ---------------------------------------------------------------------------


class MockTokenizerWithTools(MockTokenizer):
    """Extended mock tokenizer that supports tools kwarg in apply_chat_template."""

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        result = ""
        if tools:
            import json

            result += f"<tools>{json.dumps(tools)}</tools>\n"
        for msg in messages:
            content = msg.get("content", "") or ""
            result += f"<{msg['role']}>{content}</{msg['role']}>"
        if add_generation_prompt:
            result += "<assistant>"
        if tokenize:
            return self.encode(result)
        return result


@pytest.fixture
def mock_server_with_tools():
    """Mock server with tool-aware tokenizer."""
    server = ServerHarness()
    server.tokenizer = MockTokenizerWithTools()

    class Config:
        model_name = "test_model"

    server.config = Config()
    return server


def _setup_chat_completion(server, tokenizer, messages, output_texts, tools=None):
    """Helper: set up mock tokens_and_logprobs for a chat_completion call."""
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, tools=tools
    )
    prompt_tokens = tokenizer.encode(prompt)
    output_tokens_list = [[ord(c) for c in text] for text in output_texts]
    output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
    finish_reasons = ["stop"] * len(output_texts)

    server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        output_tokens_list=output_tokens_list,
        output_logprobs_list=output_logprobs_list,
        finish_reasons=finish_reasons,
    )
    return prompt


@pytest.mark.asyncio
async def test_tool_call_parsing_outbound(mock_server_with_tools):
    """Model generates <tool_call> → chat_completion returns structured tool_calls."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Search cats"}]
    raw_output = (
        '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
    )

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )

    assert len(result.choices) == 1
    choice = result.choices[0]
    assert choice.finish_reason == "tool_calls"
    assert choice.message.tool_calls is not None
    assert len(choice.message.tool_calls) == 1
    tc = choice.message.tool_calls[0]
    assert tc["function"]["name"] == "search"

    # Node should have raw text (not parsed)
    state = managed.get_state()
    assert len(state["nodes"]) == 1


@pytest.mark.asyncio
async def test_tool_choice_none_skips(mock_server_with_tools):
    """tool_choice='none' returns raw text, no parsing."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="none"
    )

    assert result.choices[0].message.tool_calls is None
    assert result.choices[0].finish_reason == "stop"
    # Raw text should be content
    assert "<tool_call>" in result.choices[0].message.content


@pytest.mark.asyncio
async def test_no_tool_parser_passes_through(mock_server_with_tools):
    """Without tool_parser, tools kwarg is ignored — no parsing."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        # No tool_parser
    )

    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools, mock_server_with_tools.tokenizer, messages, [raw_output]
    )

    result = await managed.chat_completion(messages=messages)

    # No tool parsing — raw text as content
    assert result.choices[0].message.tool_calls is None
    assert result.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_tool_call_multi_turn_extends_node(mock_server_with_tools):
    """Multi-turn with tool calls should extend to 1 node."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tok = mock_server_with_tools.tokenizer
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]

    # Step 1: user → tool_call
    messages_1 = [{"role": "user", "content": "Search cats"}]
    output_1 = '<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools, tok, messages_1, [output_1], tools=tools
    )

    result_1 = await managed.chat_completion(
        messages=messages_1, tools=tools, tool_choice="auto"
    )
    tc_1 = result_1.choices[0].message.tool_calls

    assert len(managed.get_state()["nodes"]) == 1

    # Step 2: include tool result → plain response
    # Reconstruct the assistant message with tool_calls for the translator
    messages_2 = [
        {"role": "user", "content": "Search cats"},
        {"role": "assistant", "content": None, "tool_calls": tc_1},
        {"role": "tool", "tool_call_id": tc_1[0]["id"], "content": "Found 5 cats"},
    ]

    # The translator will reconstruct the tool_call to raw text,
    # so we need the prompt to match what it produces
    output_2 = "Here are 5 cats!"
    prompt_2 = tok.apply_chat_template(
        managed._get_translator().convert_messages_for_template(messages_2),
        tokenize=False,
        add_generation_prompt=True,
        tools=tools,
    )
    prompt_tokens_2 = tok.encode(prompt_2)
    output_tokens_2 = [ord(c) for c in output_2]
    mock_server_with_tools.set_tokens_and_logprobs_response(
        prompt=prompt_2,
        prompt_tokens=prompt_tokens_2,
        output_tokens_list=[output_tokens_2],
        output_logprobs_list=[[-0.1] * len(output_tokens_2)],
        finish_reasons=["stop"],
    )

    result_2 = await managed.chat_completion(
        messages=messages_2, tools=tools, tool_choice="auto"
    )
    assert result_2.choices[0].message.content == output_2

    # Still 1 node — step 2 extended step 1
    assert len(managed.get_state()["nodes"]) == 1


@pytest.mark.asyncio
async def test_tool_call_multiple_tools_parsed(mock_server_with_tools):
    """Multiple tool calls in one response are all parsed."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        {"type": "function", "function": {"name": "get_time", "parameters": {}}},
    ]
    messages = [{"role": "user", "content": "Weather and time?"}]
    raw_output = (
        '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>\n'
        '<tool_call>{"name": "get_time", "arguments": {"tz": "PST"}}</tool_call>'
    )
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )

    assert result.choices[0].finish_reason == "tool_calls"
    assert len(result.choices[0].message.tool_calls) == 2
    names = {tc["function"]["name"] for tc in result.choices[0].message.tool_calls}
    assert names == {"get_weather", "get_time"}


@pytest.mark.asyncio
async def test_tool_call_node_masking(mock_server_with_tools):
    """Nodes have proper masking even with tool parsing active."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    await managed.chat_completion(messages=messages, tools=tools)

    node = managed.get_state()["nodes"][0]

    # Lengths must match
    assert len(node.tokens) == len(node.masked_tokens) == len(node.logprobs)

    # Should have masked prompt tokens and actual completion tokens
    num_masked = sum(1 for t in node.masked_tokens if t == -100)
    num_actual = sum(1 for t in node.masked_tokens if t != -100)
    assert num_masked > 0
    assert num_actual > 0

    # Prompt logprobs = 1.0, completion logprobs < 0
    assert all(lp == 1.0 for lp in node.logprobs[:num_masked])
    assert all(lp < 0 for lp in node.logprobs[num_masked:])


if __name__ == "__main__":
    # Run tests
    pytest.main([__file__, "-v"])

atroposlib/tests/test_managed_server_proxy.py (new file, 852 lines)

@@ -0,0 +1,852 @@
"""Mock-based tests for the ManagedServer OpenAI proxy.
|
||||
|
||||
Uses ServerHarness as the backend — no real model or GPU needed.
|
||||
Tests the full HTTP layer: session management, chat completions,
|
||||
tool call translation, render endpoint, nodes, cleanup.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock tokenizer (same as test_managed_server.py / test_tool_call_translator.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
def __init__(self):
|
||||
self.eos_token_id = 2
|
||||
self.bos_token_id = 1
|
||||
|
||||
def encode(self, text, add_special_tokens=True):
|
||||
tokens = [ord(c) for c in text]
|
||||
if add_special_tokens:
|
||||
tokens = [self.bos_token_id] + tokens
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens, skip_special_tokens=False):
|
||||
if skip_special_tokens:
|
||||
tokens = [
|
||||
t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
|
||||
]
|
||||
return "".join([chr(t) if t > 31 else "" for t in tokens])
|
||||
|
||||
def get_vocab(self):
|
||||
return {chr(i): i for i in range(128)}
|
||||
|
||||
def apply_chat_template(
|
||||
self, messages, tokenize=False, add_generation_prompt=True, tools=None
|
||||
):
|
||||
result = ""
|
||||
if tools:
|
||||
result += f"<tools>{json.dumps(tools)}</tools>\n"
|
||||
for msg in messages:
|
||||
content = msg.get("content", "") or ""
|
||||
result += f"<{msg['role']}>{content}</{msg['role']}>"
|
||||
if add_generation_prompt:
|
||||
result += "<assistant>"
|
||||
if tokenize:
|
||||
return self.encode(result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
"""Create a mock server backend with tokenizer."""
|
||||
server = ServerHarness()
|
||||
server.tokenizer = MockTokenizer()
|
||||
# ServerManager's _select_server checks these attributes
|
||||
server.server_healthy = True
|
||||
|
||||
class Config:
|
||||
model_name = "test_model"
|
||||
|
||||
server.config = Config()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_manager(mock_backend):
|
||||
"""Create a ServerManager wrapping the mock backend."""
|
||||
# Can't use ServerManager constructor with empty configs, so build manually
|
||||
mgr = object.__new__(ServerManager)
|
||||
mgr.max_n_completions = 8
|
||||
mgr.reasoning_config = None
|
||||
mgr.servers = [mock_backend]
|
||||
return mgr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(server_manager):
|
||||
"""Create a test client for the proxy app."""
|
||||
tokenizer = MockTokenizer()
|
||||
app = create_app(
|
||||
server_manager=server_manager,
|
||||
tokenizer=tokenizer,
|
||||
model_name="test_model",
|
||||
)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_and_backend(mock_backend, server_manager):
|
||||
"""Return both client and backend for tests that need to set up mock responses."""
|
||||
tokenizer = MockTokenizer()
|
||||
app = create_app(
|
||||
server_manager=server_manager,
|
||||
tokenizer=tokenizer,
|
||||
model_name="test_model",
|
||||
)
|
||||
return TestClient(app), mock_backend, tokenizer
|
||||
|
||||
|
||||
def _setup_completion(
|
||||
backend, tokenizer, prompt_text, output_texts, finish_reasons=None
|
||||
):
|
||||
"""Helper to set up a mock tokens_and_logprobs response."""
|
||||
prompt_tokens = tokenizer.encode(prompt_text)
|
||||
output_tokens_list = [[ord(c) for c in text] for text in output_texts]
|
||||
output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
|
||||
if finish_reasons is None:
|
||||
finish_reasons = ["stop"] * len(output_texts)
|
||||
|
||||
backend.set_tokens_and_logprobs_response(
|
||||
prompt=prompt_text,
|
||||
prompt_tokens=prompt_tokens,
|
||||
output_tokens_list=output_tokens_list,
|
||||
output_logprobs_list=output_logprobs_list,
|
||||
finish_reasons=finish_reasons,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health / Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHealth:
|
||||
def test_health(self, client):
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["model"] == "test_model"
|
||||
assert data["sessions"] == 0
|
||||
|
||||
def test_models(self, client):
|
||||
resp = client.get("/v1/models")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["object"] == "list"
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["id"] == "test_model"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionManagement:
|
||||
def test_create_session(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "uuid" in data
|
||||
assert data["model_name"] == "test_model"
|
||||
assert data["tool_parser"] == "hermes"
|
||||
|
||||
def test_create_session_custom_parser(self, client):
|
||||
resp = client.post("/sessions/create", json={"tool_parser": "hermes"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["tool_parser"] == "hermes"
|
||||
|
||||
def test_list_sessions(self, client):
|
||||
# Create 3 sessions
|
||||
uuids = []
|
||||
for _ in range(3):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuids.append(resp.json()["uuid"])
|
||||
|
||||
resp = client.get("/sessions")
|
||||
assert resp.status_code == 200
|
||||
sessions = resp.json()["sessions"]
|
||||
assert len(sessions) == 3
|
||||
listed_uuids = {s["uuid"] for s in sessions}
|
||||
assert listed_uuids == set(uuids)
|
||||
|
||||
def test_delete_session(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = client.delete(f"/{uuid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "deleted"
|
||||
|
||||
# Should be gone
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_delete_nonexistent_session(self, client):
|
||||
resp = client.delete("/nonexistent-uuid")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_session_not_found(self, client):
|
||||
resp = client.post(
|
||||
"/nonexistent-uuid/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chat Completions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChatCompletions:
|
||||
def test_basic_completion(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
# Create session
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
# Set up mock response
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["Hi there!"])
|
||||
|
||||
# Make request
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages, "max_tokens": 100},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
assert data["object"] == "chat.completion"
|
||||
assert data["model"] == "test_model"
|
||||
assert len(data["choices"]) == 1
|
||||
assert data["choices"][0]["message"]["role"] == "assistant"
|
||||
assert data["choices"][0]["message"]["content"] == "Hi there!"
|
||||
assert data["choices"][0]["finish_reason"] == "stop"
|
||||
assert data["id"].startswith("chatcmpl-")
|
||||
|
||||
def test_completion_with_n(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [{"role": "user", "content": "Pick a number"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["One", "Two", "Three"])
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages, "n": 3, "max_tokens": 50},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["choices"]) == 3
|
||||
|
||||
def test_empty_messages_error(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": []},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_completion_with_system_prompt(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["choices"][0]["message"]["content"] == "Hello!"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool Call Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolCalls:
|
||||
def test_tool_call_outbound(self, client_and_backend):
|
||||
"""Model generates <tool_call> tags → response has structured tool_calls."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
|
||||
messages = [{"role": "user", "content": "Search cats"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, tools=tools
|
||||
)
|
||||
|
||||
# Model output includes tool call tags
|
||||
raw_output = (
|
||||
'<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, [raw_output])
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages, "tools": tools},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
choice = data["choices"][0]
|
||||
|
||||
assert choice["finish_reason"] == "tool_calls"
|
||||
assert "tool_calls" in choice["message"]
|
||||
assert len(choice["message"]["tool_calls"]) == 1
|
||||
tc = choice["message"]["tool_calls"][0]
|
||||
assert tc["function"]["name"] == "search"
|
||||
assert json.loads(tc["function"]["arguments"]) == {"query": "cats"}
|
||||
|
||||
def test_tool_choice_none(self, client_and_backend):
|
||||
"""tool_choice=none → no parsing, raw text returned."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
|
||||
messages = [{"role": "user", "content": "Search cats"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, tools=tools
|
||||
)
|
||||
|
||||
raw_output = (
|
||||
'<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, [raw_output])
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages, "tools": tools, "tool_choice": "none"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
choice = resp.json()["choices"][0]
|
||||
|
||||
# Should NOT have tool_calls since tool_choice is "none"
|
||||
assert choice["finish_reason"] == "stop"
|
||||
assert (
|
||||
"tool_calls" not in choice["message"]
|
||||
or choice["message"].get("tool_calls") is None
|
||||
)
|
||||
|
||||
def test_nodes_preserve_raw_text(self, client_and_backend):
|
||||
"""ManagedServer nodes should have raw text, not parsed tool_calls."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
|
||||
messages = [{"role": "user", "content": "Search cats"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, tools=tools
|
||||
)
|
||||
|
||||
raw_output = (
|
||||
'<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, [raw_output])
|
||||
|
||||
client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages, "tools": tools},
|
||||
)
|
||||
|
||||
# Check nodes — should have the raw tokens, not parsed
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.status_code == 200
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Render Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRender:
|
||||
def test_render_basic(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions/render",
|
||||
json={"messages": [{"role": "user", "content": "Hello"}]},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "prompt_text" in data
|
||||
assert "token_ids" in data
|
||||
assert "num_tokens" in data
|
||||
assert data["num_tokens"] == len(data["token_ids"])
|
||||
assert "<user>Hello</user>" in data["prompt_text"]
|
||||
|
||||
def test_render_with_tools(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [{"type": "function", "function": {"name": "search"}}]
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions/render",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"tools": tools,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Tool definitions should appear in the rendered prompt
|
||||
assert "search" in data["prompt_text"]
|
||||
|
||||
def test_render_does_not_create_nodes(self, client):
|
||||
"""Render should not cause any generation or node creation."""
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
client.post(
|
||||
f"/{uuid}/v1/chat/completions/render",
|
||||
json={"messages": [{"role": "user", "content": "Hi"}]},
|
||||
)
|
||||
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.json()["nodes"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Nodes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNodes:
|
||||
def test_get_nodes_empty(self, client):
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["nodes"] == []
|
||||
|
||||
def test_get_nodes_after_completion(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
|
||||
|
||||
client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages},
|
||||
)
|
||||
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.status_code == 200
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) == 1
|
||||
|
||||
node = nodes[0]
|
||||
assert "tokens" in node
|
||||
assert "masked_tokens" in node
|
||||
assert "logprobs" in node
|
||||
assert "full_text" in node
|
||||
assert (
|
||||
len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"])
|
||||
)
|
||||
|
||||
def test_nodes_have_proper_masking(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt_tokens = tokenizer.encode(prompt_text)
|
||||
prompt_len = len(prompt_tokens)
|
||||
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
|
||||
|
||||
client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages},
|
||||
)
|
||||
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
node = resp.json()["nodes"][0]
|
||||
|
||||
# Prompt tokens should be masked with -100
|
||||
assert all(t == -100 for t in node["masked_tokens"][:prompt_len])
|
||||
# Prompt logprobs should be 1.0
|
||||
assert all(lp == 1.0 for lp in node["logprobs"][:prompt_len])
|
||||
# Completion logprobs should be actual values (negative)
|
||||
assert all(lp < 0 for lp in node["logprobs"][prompt_len:])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep multi-step node handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiStepNodeHandling:
|
||||
"""Test that multi-step conversations with tool calls produce exactly 1 node.
|
||||
|
||||
Simulates a realistic 10+ message agentic conversation:
|
||||
user → assistant(tool_call) → tool_result → assistant(text) →
|
||||
user → assistant(tool_call) → tool_result → assistant(tool_call) →
|
||||
tool_result → assistant(text) → user → assistant(text)
|
||||
|
||||
Each step extends the previous node, so we should end up with exactly
|
||||
1 node containing the full tokenized conversation.
|
||||
"""
|
||||
|
||||
def _do_step(
|
||||
self,
|
||||
client,
|
||||
backend,
|
||||
tokenizer,
|
||||
uuid,
|
||||
messages,
|
||||
output_text,
|
||||
tools=None,
|
||||
expect_tool_calls=False,
|
||||
):
|
||||
"""Helper: use render endpoint to get exact prompt, set up mock, call endpoint."""
|
||||
body = {"messages": messages, "max_tokens": 200}
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
|
||||
# Use the render endpoint to get the exact prompt the proxy will generate
|
||||
# (this includes tool_call reconstruction through the translator)
|
||||
render_resp = client.post(f"/{uuid}/v1/chat/completions/render", json=body)
|
||||
assert render_resp.status_code == 200, f"Render failed: {render_resp.json()}"
|
||||
prompt_text = render_resp.json()["prompt_text"]
|
||||
|
||||
_setup_completion(backend, tokenizer, prompt_text, [output_text])
|
||||
|
||||
resp = client.post(f"/{uuid}/v1/chat/completions", json=body)
|
||||
assert resp.status_code == 200, f"Step failed: {resp.json()}"
|
||||
return resp.json()
|
||||
|
||||
def test_10_message_conversation_one_node(self, client_and_backend):
|
||||
"""Full 10-message conversation with tool calls → exactly 1 node."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "get_weather", "parameters": {}}},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "get_forecast", "parameters": {}},
|
||||
},
|
||||
]
|
||||
|
||||
# -- Step 1: user asks about weather --
|
||||
messages = [{"role": "user", "content": "What's the weather in SF?"}]
|
||||
output_1 = '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>'
|
||||
data = self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, output_1, tools=tools
|
||||
)
|
||||
assert data["choices"][0]["finish_reason"] == "tool_calls"
|
||||
tc_1 = data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
# Check: 1 node so far
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after step 1, got {len(nodes)}"
|
||||
|
||||
# -- Step 2: tool result --
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in SF?"},
|
||||
{"role": "assistant", "content": None, "tool_calls": tc_1},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_1[0]["id"],
|
||||
"content": "72°F and sunny",
|
||||
},
|
||||
]
|
||||
output_2 = "The weather in SF is 72°F and sunny! Want the forecast too?"
|
||||
self._do_step(client, backend, tokenizer, uuid, messages, output_2, tools=tools)
|
||||
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after step 2, got {len(nodes)}"
|
||||
|
||||
# -- Step 3: user says yes --
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": output_2},
|
||||
{"role": "user", "content": "Yes please, get the forecast"},
|
||||
]
|
||||
)
|
||||
output_3 = '<tool_call>{"name": "get_forecast", "arguments": {"city": "SF"}}</tool_call>'
|
||||
data = self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, output_3, tools=tools
|
||||
)
|
||||
tc_3 = data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after step 3, got {len(nodes)}"
|
||||
|
||||
# -- Step 4: forecast tool result --
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": None, "tool_calls": tc_3},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_3[0]["id"],
|
||||
"content": "Rain expected tomorrow",
|
||||
},
|
||||
]
|
||||
)
|
||||
output_4 = "The forecast says rain is expected tomorrow in SF."
|
||||
self._do_step(client, backend, tokenizer, uuid, messages, output_4, tools=tools)
|
||||
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after step 4, got {len(nodes)}"
|
||||
|
||||
# -- Step 5: user asks about another city --
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": output_4},
|
||||
{"role": "user", "content": "What about NYC?"},
|
||||
]
|
||||
)
|
||||
output_5 = '<tool_call>{"name": "get_weather", "arguments": {"city": "NYC"}}</tool_call>'
|
||||
data = self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, output_5, tools=tools
|
||||
)
|
||||
tc_5 = data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after step 5, got {len(nodes)}"
|
||||
|
||||
# -- Step 6: NYC tool result --
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": None, "tool_calls": tc_5},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_5[0]["id"],
|
||||
"content": "55°F and cloudy",
|
||||
},
|
||||
]
|
||||
)
|
||||
output_6 = "NYC is 55°F and cloudy. Quite different from SF!"
|
||||
self._do_step(client, backend, tokenizer, uuid, messages, output_6, tools=tools)
|
||||
|
||||
# -- FINAL CHECK: still exactly 1 node after 6 completions / 12+ messages --
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert (
|
||||
len(nodes) == 1
|
||||
), f"Expected 1 node after full conversation, got {len(nodes)}"
|
||||
|
||||
# Verify the node has proper structure
|
||||
node = nodes[0]
|
||||
assert (
|
||||
len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"])
|
||||
)
|
||||
assert len(node["tokens"]) > 0
|
||||
|
||||
# Verify masking: there should be SOME -100 (prompt) and SOME actual tokens
|
||||
num_masked = sum(1 for t in node["masked_tokens"] if t == -100)
|
||||
num_actual = sum(1 for t in node["masked_tokens"] if t != -100)
|
||||
assert num_masked > 0, "Should have masked prompt tokens"
|
||||
assert num_actual > 0, "Should have unmasked completion tokens"
|
||||
|
||||
def test_plain_multi_turn_no_tools_one_node(self, client_and_backend):
|
||||
"""5-turn conversation without tools → exactly 1 node."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
conversation = []
|
||||
|
||||
for i in range(5):
|
||||
# Add user message
|
||||
conversation.append({"role": "user", "content": f"Turn {i+1} question"})
|
||||
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
conversation, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
output = f"Response to turn {i+1}"
|
||||
_setup_completion(backend, tokenizer, prompt_text, [output])
|
||||
|
||||
resp = client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": conversation},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Add assistant response for next turn
|
||||
conversation.append({"role": "assistant", "content": output})
|
||||
|
||||
# After 5 turns (10 messages), should still be 1 node
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node after 5 turns, got {len(nodes)}"
|
||||
|
||||
def test_tool_then_plain_then_tool_one_node(self, client_and_backend):
|
||||
"""Mixed: tool call → plain text → tool call → plain → exactly 1 node."""
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
|
||||
|
||||
# Step 1: tool call
|
||||
messages = [{"role": "user", "content": "Search for cats"}]
|
||||
output = '<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
|
||||
data = self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, output, tools=tools
|
||||
)
|
||||
tc = data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
# Step 2: tool result → plain response
|
||||
messages = [
|
||||
{"role": "user", "content": "Search for cats"},
|
||||
{"role": "assistant", "content": None, "tool_calls": tc},
|
||||
{"role": "tool", "tool_call_id": tc[0]["id"], "content": "Found 10 cats"},
|
||||
]
|
||||
self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, "Here are 10 cats!", tools=tools
|
||||
)
|
||||
|
||||
# Step 3: user asks for more → another tool call
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": "Here are 10 cats!"},
|
||||
{"role": "user", "content": "Search for dogs too"},
|
||||
]
|
||||
)
|
||||
output = '<tool_call>{"name": "search", "arguments": {"q": "dogs"}}</tool_call>'
|
||||
data = self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, output, tools=tools
|
||||
)
|
||||
tc2 = data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
# Step 4: tool result → plain response
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": None, "tool_calls": tc2},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc2[0]["id"],
|
||||
"content": "Found 5 dogs",
|
||||
},
|
||||
]
|
||||
)
|
||||
self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, "Found 5 dogs too!", tools=tools
|
||||
)
|
||||
|
||||
# Step 5: plain follow-up, no tools
|
||||
messages.extend(
|
||||
[
|
||||
{"role": "assistant", "content": "Found 5 dogs too!"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
)
|
||||
self._do_step(
|
||||
client, backend, tokenizer, uuid, messages, "You're welcome!", tools=tools
|
||||
)
|
||||
|
||||
# 5 completion steps, 11 messages — still 1 node
|
||||
nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
|
||||
assert len(nodes) == 1, f"Expected 1 node, got {len(nodes)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_delete_resets_nodes(self, client_and_backend):
|
||||
client, backend, tokenizer = client_and_backend
|
||||
|
||||
resp = client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
_setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
|
||||
|
||||
client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={"messages": messages},
|
||||
)
|
||||
|
||||
# Delete
|
||||
resp = client.delete(f"/{uuid}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Session gone
|
||||
resp = client.get(f"/{uuid}/nodes")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorFormat:
|
||||
def test_404_is_openai_format(self, client):
|
||||
resp = client.get("/nonexistent-uuid/nodes")
|
||||
assert resp.status_code == 404
|
||||
data = resp.json()
|
||||
assert "error" in data
|
||||
assert "message" in data["error"]
|
||||
assert "type" in data["error"]
|
||||
assert "code" in data["error"]
|
||||

atroposlib/tests/test_managed_server_proxy_integration.py (new file, 363 lines)

@@ -0,0 +1,363 @@
"""Integration tests for ManagedServer OpenAI proxy against real vLLM backend.
|
||||
|
||||
Spins up example_trainer/vllm_api_server.py with Qwen3-4B as a subprocess.
|
||||
Requires GPU — skipped by default. Run with:
|
||||
|
||||
pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VLLM_PORT = 8123
|
||||
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
||||
PROXY_MODEL = VLLM_MODEL
|
||||
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
|
||||
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
|
||||
VENV_PYTHON = sys.executable # use the current interpreter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vllm_backend():
|
||||
"""Start vLLM api server as a subprocess. Module-scoped so it's shared."""
|
||||
cmd = [
|
||||
VENV_PYTHON,
|
||||
VLLM_SCRIPT,
|
||||
"--model",
|
||||
VLLM_MODEL,
|
||||
"--port",
|
||||
str(VLLM_PORT),
|
||||
"--max-model-len",
|
||||
"32000",
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
]
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=REPO_ROOT,
|
||||
)
|
||||
|
||||
# Wait for health
|
||||
deadline = time.time() + 180 # 3 min for model loading
|
||||
healthy = False
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
|
||||
if resp.status_code == 200:
|
||||
healthy = True
|
||||
break
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
pass
|
||||
|
||||
if proc.poll() is not None:
|
||||
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
||||
pytest.fail(
|
||||
f"vLLM server exited early (code={proc.returncode}):\n{stdout[-3000:]}"
|
||||
)
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
if not healthy:
|
||||
proc.kill()
|
||||
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
||||
pytest.fail(f"vLLM server didn't become healthy within 180s:\n{stdout[-3000:]}")
|
||||
|
||||
yield proc
|
||||
|
||||
proc.send_signal(signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer_instance():
|
||||
return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def proxy_client(vllm_backend, tokenizer_instance):
|
||||
"""Create a test client for the proxy backed by the real vLLM server."""
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
config = APIServerConfig(
|
||||
model_name=VLLM_MODEL,
|
||||
base_url=VLLM_BASE_URL,
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
)
|
||||
server_manager = ServerManager(configs=[config])
|
||||
|
||||
app = create_app(
|
||||
server_manager=server_manager,
|
||||
tokenizer=tokenizer_instance,
|
||||
model_name=VLLM_MODEL,
|
||||
)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealChatCompletion:
|
||||
def test_basic_completion(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello in one word."}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["choices"]) == 1
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
assert content is not None
|
||||
assert len(content) > 0
|
||||
assert data["model"] == VLLM_MODEL
|
||||
|
||||
def test_n_completions(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Pick a random number"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 1.0,
|
||||
"n": 4,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["choices"]) == 4
|
||||
|
||||
# Check nodes
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
assert len(resp.json()["nodes"]) == 4
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealLogprobs:
|
||||
def test_logprobs_are_valid(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) == 1
|
||||
|
||||
node = nodes[0]
|
||||
# Find where completion starts (logprobs transition from 1.0 to negative)
|
||||
prompt_end = 0
|
||||
for i, lp in enumerate(node["logprobs"]):
|
||||
if lp != 1.0:
|
||||
prompt_end = i
|
||||
break
|
||||
|
||||
# Prompt logprobs should be 1.0
|
||||
assert all(lp == 1.0 for lp in node["logprobs"][:prompt_end])
|
||||
# Completion logprobs should be negative
|
||||
completion_lps = [lp for lp in node["logprobs"][prompt_end:] if lp != 1.0]
|
||||
assert len(completion_lps) > 0
|
||||
assert all(lp < 0 for lp in completion_lps)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealTokenAlignment:
|
||||
def test_tokens_decode_to_full_text(self, proxy_client, tokenizer_instance):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say exactly: test123"}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
node = resp.json()["nodes"][0]
|
||||
|
||||
# Lengths must match
|
||||
assert len(node["tokens"]) == len(node["masked_tokens"])
|
||||
assert len(node["tokens"]) == len(node["logprobs"])
|
||||
|
||||
# Decode tokens and check they match full_text
|
||||
decoded = tokenizer_instance.decode(node["tokens"])
|
||||
# The decoded text should be close to (or contain) the full_text
|
||||
# Exact match may differ due to special token handling, but content should match
|
||||
assert len(decoded) > 0
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealRender:
|
||||
def test_render_matches_tokenizer(self, proxy_client, tokenizer_instance):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions/render",
|
||||
json={"messages": messages},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# Compare with direct tokenizer rendering
|
||||
expected_text = tokenizer_instance.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
assert data["prompt_text"] == expected_text
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealSequenceExtension:
|
||||
def test_multi_turn_extends(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
# Turn 1
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
turn1_content = resp.json()["choices"][0]["message"]["content"]
|
||||
|
||||
# Turn 2 — extends turn 1
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"},
|
||||
{"role": "assistant", "content": turn1_content},
|
||||
{"role": "user", "content": "Now say goodbye"},
|
||||
],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Should have nodes (extension behavior depends on prefix matching)
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) >= 1
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealConcurrentSessions:
|
||||
def test_sessions_independent(self, proxy_client):
|
||||
"""Multiple sessions should not contaminate each other."""
|
||||
uuids = []
|
||||
for _ in range(3):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuids.append(resp.json()["uuid"])
|
||||
|
||||
# Complete on each
|
||||
for i, uuid in enumerate(uuids):
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": f"Count to {i+1}"}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Each should have exactly 1 node
|
||||
for uuid in uuids:
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
assert len(resp.json()["nodes"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealOpenAIClientCompat:
|
||||
def test_openai_client_works(self, proxy_client):
|
||||
"""Verify the standard openai Python client can talk to our proxy."""
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
# The TestClient doesn't expose a real port, so we test the
|
||||
# response format is compatible by checking structure manually
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
data = resp.json()
|
||||
|
||||
# Verify all fields the openai client expects
|
||||
assert "id" in data
|
||||
assert "object" in data
|
||||
assert data["object"] == "chat.completion"
|
||||
assert "created" in data
|
||||
assert "model" in data
|
||||
assert "choices" in data
|
||||
assert isinstance(data["choices"], list)
|
||||
for choice in data["choices"]:
|
||||
assert "index" in choice
|
||||
assert "message" in choice
|
||||
assert "finish_reason" in choice
|
||||
assert "role" in choice["message"]
|
||||
assert "content" in choice["message"]
|
||||
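
Since TestClient never binds a real socket, the last test above checks structure rather than driving the stock openai client end to end. A hedged sketch of the real round trip, assuming the app is served with uvicorn on a placeholder port (none of this is in the diff):

    import requests
    from openai import OpenAI

    # Assumption: create_app(...) is served like any FastAPI app, e.g.
    #   uvicorn.run(app, port=9000)
    PROXY = "http://localhost:9000"
    uuid = requests.post(f"{PROXY}/sessions/create", json={}).json()["uuid"]

    # The /{uuid}/v1 prefix makes each session look like a stock OpenAI endpoint.
    client = OpenAI(base_url=f"{PROXY}/{uuid}/v1", api_key="unused")
    reply = client.chat.completions.create(
        model="Qwen/Qwen3-4B-Thinking-2507",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=10,
    )
    print(reply.choices[0].message.content)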

atroposlib/tests/test_tool_call_translator.py (new file, 478 lines)

@@ -0,0 +1,478 @@
"""Unit tests for ToolCallTranslator — vLLM parser wrapper and lookup table.
|
||||
|
||||
These are pure logic tests, no server or model needed. Uses a mock tokenizer.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock tokenizer (same one from test_managed_server.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
def __init__(self):
|
||||
self.eos_token_id = 2
|
||||
self.bos_token_id = 1
|
||||
|
||||
def encode(self, text, add_special_tokens=True):
|
||||
tokens = [ord(c) for c in text]
|
||||
if add_special_tokens:
|
||||
tokens = [self.bos_token_id] + tokens
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens, skip_special_tokens=False):
|
||||
if skip_special_tokens:
|
||||
tokens = [
|
||||
t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
|
||||
]
|
||||
return "".join([chr(t) if t > 31 else "" for t in tokens])
|
||||
|
||||
def get_vocab(self):
|
||||
# Minimal vocab for the parser — hermes parser calls this
|
||||
return {chr(i): i for i in range(128)}
|
||||
|
||||
def apply_chat_template(
|
||||
self, messages, tokenize=False, add_generation_prompt=True, tools=None
|
||||
):
|
||||
result = ""
|
||||
if tools:
|
||||
result += f"<tools>{json.dumps(tools)}</tools>\n"
|
||||
for msg in messages:
|
||||
result += f"<{msg['role']}>{msg.get('content', '')}</{msg['role']}>"
|
||||
if add_generation_prompt:
|
||||
result += "<assistant>"
|
||||
if tokenize:
|
||||
return self.encode(result)
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def translator():
|
||||
tok = MockTokenizer()
|
||||
return ToolCallTranslator(tokenizer=tok, parser_name="hermes")
|
||||
|
||||
|
||||
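# A minimal end-to-end sketch of the flow exercised below (illustrative only;
# the literal values are made up, and the API surface shown is exactly what
# these tests assume of ToolCallTranslator):
#
#     raw = '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
#     content, tool_calls, finish_reason = translator.parse_model_output(
#         raw,
#         tool_choice="auto",
#         tools=[{"type": "function", "function": {"name": "search"}}],
#     )
#     # finish_reason == "tool_calls", and translator.call_id_to_raw_text now
#     # maps tool_calls[0].id back to `raw`, so the exact model text can be
#     # restored later:
#     tc_dicts = [tc.model_dump() for tc in tool_calls]
#     assert translator.reconstruct_raw_text_from_tool_calls(tc_dicts) == raw
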
# ---------------------------------------------------------------------------
# Outbound: model output → OpenAI format
# ---------------------------------------------------------------------------


class TestParseModelOutput:
    def test_single_tool_call(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "search"
        assert json.loads(tool_calls[0].function.arguments) == {"query": "cats"}
        # content is None or empty when the full text is a tool call
        assert content is None or content.strip() == ""

    def test_multiple_tool_calls(self, translator):
        raw = (
            '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>\n'
            '<tool_call>{"name": "get_time", "arguments": {"tz": "PST"}}</tool_call>'
        )
        tools = [
            {"type": "function", "function": {"name": "get_weather"}},
            {"type": "function", "function": {"name": "get_time"}},
        ]
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )

        assert finish_reason == "tool_calls"
        assert len(tool_calls) == 2
        names = {tc.function.name for tc in tool_calls}
        assert names == {"get_weather", "get_time"}

    def test_no_tool_calls(self, translator):
        raw = "The weather in SF is 72 degrees."
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw

    def test_content_before_tool_call(self, translator):
        raw = (
            "Let me search for that.\n"
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1
        # Content before the tool call tag should be preserved
        assert content is not None
        assert "search for that" in content

    def test_tool_choice_none_skips_parsing(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="none",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw  # Raw text returned as-is

    def test_no_tools_skips_parsing(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw, tool_choice="auto", tools=None
        )

        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw

    def test_malformed_json_graceful_fallback(self, translator):
        raw = "<tool_call>not valid json at all</tool_call>"
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        # Parser should handle this gracefully — no tool calls, with the raw
        # text left as content
        assert finish_reason == "stop"
        assert tool_calls is None

    def test_unclosed_tool_call(self, translator):
        raw = '<tool_call>{"name": "search", "arguments": {"query": "cats"}}'
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        # The hermes regex has a branch for unclosed tags
        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1

    def test_nested_json_arguments(self, translator):
        args = {
            "filter": {
                "type": "date",
                "range": {"start": "2024-01-01", "end": "2024-12-31"},
            }
        }
        raw = f'<tool_call>{{"name": "search", "arguments": {json.dumps(args)}}}</tool_call>'
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "tool_calls"
        assert json.loads(tool_calls[0].function.arguments) == args

    def test_tool_call_with_think_tags(self, translator):
        raw = (
            "<think>I should search for this information.</think>\n"
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        # Think content should end up in the content field
        if content:
            assert "think" in content or "search for this" in content


# ---------------------------------------------------------------------------
# Lookup table
# ---------------------------------------------------------------------------


class TestLookupTable:
    def test_parse_populates_lookup(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _, tool_calls, _ = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )

        assert len(translator.call_id_to_raw_text) == 1
        tc_id = tool_calls[0].id
        assert tc_id in translator.call_id_to_raw_text
        assert translator.call_id_to_raw_text[tc_id] == raw

    def test_lookup_accumulates(self, translator):
        tools = [{"type": "function", "function": {"name": "search"}}]

        raw1 = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _, tc1, _ = translator.parse_model_output(raw1, tool_choice="auto", tools=tools)

        raw2 = (
            '<tool_call>{"name": "search", "arguments": {"query": "dogs"}}</tool_call>'
        )
        _, tc2, _ = translator.parse_model_output(raw2, tool_choice="auto", tools=tools)

        assert len(translator.call_id_to_raw_text) == 2
        assert tc1[0].id in translator.call_id_to_raw_text
        assert tc2[0].id in translator.call_id_to_raw_text

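# Why the lookup matters (a note, not part of the tested contract):
# re-serializing parsed tool calls back to JSON can differ from what the model
# actually emitted (key order, whitespace), whereas the call-id to raw-text
# table lets the inbound path restore the assistant turn verbatim; see
# TestReconstructRawText below.
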
# ---------------------------------------------------------------------------
# Inbound: OpenAI messages → raw text
# ---------------------------------------------------------------------------


class TestReconstructRawText:
    def test_reconstruct_from_lookup(self, translator):
        # First, parse to populate the lookup
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]
        _, tool_calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )

        # Now reconstruct
        tc_dicts = [tc.model_dump() for tc in tool_calls]
        reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts)

        assert reconstructed == raw

    def test_reconstruct_fallback_without_lookup(self, translator):
        # Reconstruct without having parsed first — uses the fallback path
        tc_dicts = [
            {
                "id": "fake-id-123",
                "type": "function",
                "function": {"name": "search", "arguments": '{"query": "cats"}'},
            }
        ]
        reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts)

        assert "<tool_call>" in reconstructed
        assert "search" in reconstructed
        assert "cats" in reconstructed

    def test_reconstruct_empty_list(self, translator):
        assert translator.reconstruct_raw_text_from_tool_calls([]) == ""

    def test_reconstruct_multiple_tool_calls(self, translator):
        tc_dicts = [
            {
                "id": "id-1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "SF"}'},
            },
            {
                "id": "id-2",
                "type": "function",
                "function": {"name": "get_time", "arguments": '{"tz": "PST"}'},
            },
        ]
        reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts)

        assert reconstructed.count("<tool_call>") == 2
        assert "get_weather" in reconstructed
        assert "get_time" in reconstructed


# ---------------------------------------------------------------------------
# Message conversion
# ---------------------------------------------------------------------------


class TestConvertMessages:
    def test_regular_messages_pass_through(self, translator):
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi there"},
        ]
        result = translator.convert_messages_for_template(messages)

        assert result == messages

    def test_assistant_with_tool_calls_reconstructed(self, translator):
        # Parse first to populate the lookup
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]
        _, tool_calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )

        messages = [
            {"role": "user", "content": "Search for cats"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [tc.model_dump() for tc in tool_calls],
            },
            {
                "role": "tool",
                "tool_call_id": tool_calls[0].id,
                "content": "Found 5 cats",
            },
        ]

        result = translator.convert_messages_for_template(messages)

        # User message unchanged
        assert result[0] == messages[0]
        # Assistant message reconstructed to raw text
        assert result[1]["role"] == "assistant"
        assert "<tool_call>" in result[1]["content"]
        assert "tool_calls" not in result[1]
        # Tool message passed through
        assert result[2] == messages[2]

    def test_assistant_with_content_and_tool_calls(self, translator):
        messages = [
            {
                "role": "assistant",
                "content": "Let me search.",
                "tool_calls": [
                    {
                        "id": "fake-id",
                        "type": "function",
                        "function": {"name": "search", "arguments": '{"q": "x"}'},
                    }
                ],
            },
        ]

        result = translator.convert_messages_for_template(messages)

        assert result[0]["role"] == "assistant"
        assert "Let me search." in result[0]["content"]
        assert "<tool_call>" in result[0]["content"]

    def test_mixed_message_types(self, translator):
        """Only tool_call assistant messages are reconstructed."""
        messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},  # regular, no tool_calls
            {"role": "user", "content": "Search cats"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "tc-1",
                        "type": "function",
                        "function": {"name": "search", "arguments": '{"q": "cats"}'},
                    }
                ],
            },
            {"role": "tool", "tool_call_id": "tc-1", "content": "5 results"},
            {"role": "assistant", "content": "Found 5 cats!"},  # regular again
        ]

        result = translator.convert_messages_for_template(messages)

        # Messages at indices 0, 1, 2, 4, 5 should be unchanged
        assert result[0] == messages[0]
        assert result[1] == messages[1]
        assert result[2] == messages[2]
        assert result[4] == messages[4]
        assert result[5] == messages[5]
        # Message at index 3 should be reconstructed
        assert "<tool_call>" in result[3]["content"]

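# Conversion in a nutshell (illustrative values): an OpenAI-style assistant turn
#
#     {"role": "assistant", "content": None, "tool_calls": [...]}
#
# becomes a plain-text turn the chat template can render directly, e.g.
#
#     {"role": "assistant", "content": '<tool_call>{"name": ...}</tool_call>'}
#
# while system/user/tool messages pass through untouched.
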
# ---------------------------------------------------------------------------
# Roundtrip
# ---------------------------------------------------------------------------


class TestRoundtrip:
    def test_single_tool_call_roundtrip(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]

        # Parse
        _, tool_calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        # Reconstruct
        tc_dicts = [tc.model_dump() for tc in tool_calls]
        reconstructed = translator.reconstruct_raw_text_from_tool_calls(tc_dicts)

        assert reconstructed == raw

    def test_tool_call_empty_arguments(self, translator):
        raw = '<tool_call>{"name": "list_all", "arguments": {}}</tool_call>'
        tools = [{"type": "function", "function": {"name": "list_all"}}]

        _, tool_calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        assert tool_calls is not None
        assert json.loads(tool_calls[0].function.arguments) == {}


# ---------------------------------------------------------------------------
# Decode with tool awareness
# ---------------------------------------------------------------------------


class TestDecodeToolAwareness:
    def test_decode_without_tools(self, translator):
        tokens = [72, 101, 108, 108, 111]  # "Hello"
        text = translator.decode_with_tool_awareness(tokens, has_tools=False)
        assert text == "Hello"

    def test_decode_with_tools_preserves_special(self, translator):
        # With the mock tokenizer there are no "special" tokens to strip,
        # but verify the flag is passed through correctly
        tokens = [72, 101, 108, 108, 111]
        text = translator.decode_with_tool_awareness(tokens, has_tools=True)
        assert text == "Hello"

    def test_decode_strips_bos_without_tools(self, translator):
        tokens = [1, 72, 101, 108, 108, 111]  # BOS + "Hello"
        text = translator.decode_with_tool_awareness(tokens, has_tools=False)
        assert text == "Hello"  # BOS stripped

    def test_decode_keeps_bos_with_tools(self, translator):
        tokens = [1, 72, 101, 108, 108, 111]  # BOS + "Hello"
        text = translator.decode_with_tool_awareness(tokens, has_tools=True)
        # BOS (chr(1)) is not printable, so the mock tokenizer renders it as "",
        # but the flag skip_special_tokens=False is still passed through
        assert "Hello" in text

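# Presumably the reason for the has_tools split (an inference from the tests
# above, not stated in the code): with real tokenizers, tool-call tags such as
# <tool_call> may themselves be special tokens, so decoding with
# skip_special_tokens=True would erase exactly the markers the hermes parser
# needs to find.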