add tool call parsing based on vllm impl and an openai server endpoint

This commit is contained in:
dmahan93 2026-03-02 23:17:13 -06:00
parent 887a94374c
commit add42a2afb
11 changed files with 3370 additions and 34 deletions

View file

@ -5,19 +5,32 @@ def pytest_addoption(parser):
parser.addoption(
"--runproviders", action="store_true", default=False, help="run provider tests"
)
parser.addoption(
"--run-gpu",
action="store_true",
default=False,
help="run GPU integration tests",
)
def pytest_configure(config):
    """Register the custom opt-in markers so pytest does not warn on them."""
    marker_lines = (
        "providers: mark test as requires providers api keys to run",
        "gpu: mark test as requiring GPU (skipped unless --run-gpu)",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
def pytest_collection_modifyitems(config, items):
    """Skip opt-in tests unless their enabling CLI flag was passed.

    Tests marked ``providers`` require ``--runproviders``; tests marked
    ``gpu`` require ``--run-gpu``. The two groups are gated independently,
    so enabling one flag never un-skips the other group.

    NOTE: the superseded variant of this hook (an early ``return`` when
    ``--runproviders`` was given) has been removed — that early return
    would have bypassed the GPU gating entirely.
    """
    if not config.getoption("--runproviders"):
        skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
        for item in items:
            if "providers" in item.keywords:
                item.add_marker(skip_providers)
    if not config.getoption("--run-gpu"):
        skip_gpu = pytest.mark.skip(reason="need --run-gpu option to run")
        for item in items:
            if "gpu" in item.keywords:
                item.add_marker(skip_gpu)

View file

@ -493,6 +493,295 @@ async def test_multi_turn_chat_with_branching(mock_server):
assert f"More{actual_i}" in node.full_text # Has third turn
# ---------------------------------------------------------------------------
# Tool call support in ManagedServer.chat_completion()
# ---------------------------------------------------------------------------
class MockTokenizerWithTools(MockTokenizer):
    """Extended mock tokenizer that supports tools kwarg in apply_chat_template."""

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        # Render the conversation as "<role>content</role>" segments, with an
        # optional tool-definition preamble when tools are supplied.
        pieces = []
        if tools:
            import json

            pieces.append(f"<tools>{json.dumps(tools)}</tools>\n")
        for message in messages:
            role = message["role"]
            body = message.get("content", "") or ""
            pieces.append(f"<{role}>{body}</{role}>")
        if add_generation_prompt:
            pieces.append("<assistant>")
        rendered = "".join(pieces)
        return self.encode(rendered) if tokenize else rendered
@pytest.fixture
def mock_server_with_tools():
    """Mock server with tool-aware tokenizer."""
    harness = ServerHarness()
    harness.tokenizer = MockTokenizerWithTools()
    # ManagedServer reads config.model_name; a minimal anonymous class suffices.
    harness.config = type("Config", (), {"model_name": "test_model"})()
    return harness
def _setup_chat_completion(server, tokenizer, messages, output_texts, tools=None):
"""Helper: set up mock tokens_and_logprobs for a chat_completion call."""
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, tools=tools
)
prompt_tokens = tokenizer.encode(prompt)
output_tokens_list = [[ord(c) for c in text] for text in output_texts]
output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
finish_reasons = ["stop"] * len(output_texts)
server.set_tokens_and_logprobs_response(
prompt=prompt,
prompt_tokens=prompt_tokens,
output_tokens_list=output_tokens_list,
output_logprobs_list=output_logprobs_list,
finish_reasons=finish_reasons,
)
return prompt
@pytest.mark.asyncio
async def test_tool_call_parsing_outbound(mock_server_with_tools):
    """Model generates <tool_call> → chat_completion returns structured tool_calls."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Search cats"}]
    # Raw model output in hermes <tool_call> JSON format.
    raw_output = (
        '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
    )
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )
    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )
    assert len(result.choices) == 1
    choice = result.choices[0]
    # finish_reason is rewritten to "tool_calls" once a call is parsed.
    assert choice.finish_reason == "tool_calls"
    assert choice.message.tool_calls is not None
    assert len(choice.message.tool_calls) == 1
    tc = choice.message.tool_calls[0]
    assert tc["function"]["name"] == "search"
    # Node should have raw text (not parsed)
    state = managed.get_state()
    assert len(state["nodes"]) == 1
@pytest.mark.asyncio
async def test_tool_choice_none_skips(mock_server_with_tools):
    """tool_choice='none' returns raw text, no parsing."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )
    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="none"
    )
    # With parsing disabled the response stays a plain assistant message.
    assert result.choices[0].message.tool_calls is None
    assert result.choices[0].finish_reason == "stop"
    # Raw text should be content
    assert "<tool_call>" in result.choices[0].message.content
@pytest.mark.asyncio
async def test_no_tool_parser_passes_through(mock_server_with_tools):
    """Without tool_parser, tools kwarg is ignored — no parsing."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        # No tool_parser
    )
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools, mock_server_with_tools.tokenizer, messages, [raw_output]
    )
    result = await managed.chat_completion(messages=messages)
    # No tool parsing — raw text as content
    assert result.choices[0].message.tool_calls is None
    assert result.choices[0].finish_reason == "stop"
@pytest.mark.asyncio
async def test_tool_call_multi_turn_extends_node(mock_server_with_tools):
    """Multi-turn with tool calls should extend to 1 node."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tok = mock_server_with_tools.tokenizer
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    # Step 1: user → tool_call
    messages_1 = [{"role": "user", "content": "Search cats"}]
    output_1 = '<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools, tok, messages_1, [output_1], tools=tools
    )
    result_1 = await managed.chat_completion(
        messages=messages_1, tools=tools, tool_choice="auto"
    )
    tc_1 = result_1.choices[0].message.tool_calls
    assert len(managed.get_state()["nodes"]) == 1
    # Step 2: include tool result → plain response
    # Reconstruct the assistant message with tool_calls for the translator
    messages_2 = [
        {"role": "user", "content": "Search cats"},
        {"role": "assistant", "content": None, "tool_calls": tc_1},
        {"role": "tool", "tool_call_id": tc_1[0]["id"], "content": "Found 5 cats"},
    ]
    # The translator will reconstruct the tool_call to raw text,
    # so we need the prompt to match what it produces
    output_2 = "Here are 5 cats!"
    prompt_2 = tok.apply_chat_template(
        managed._get_translator().convert_messages_for_template(messages_2),
        tokenize=False,
        add_generation_prompt=True,
        tools=tools,
    )
    prompt_tokens_2 = tok.encode(prompt_2)
    output_tokens_2 = [ord(c) for c in output_2]
    # Register the step-2 canned response directly (bypasses the helper so the
    # translator-produced prompt is the one the mock matches on).
    mock_server_with_tools.set_tokens_and_logprobs_response(
        prompt=prompt_2,
        prompt_tokens=prompt_tokens_2,
        output_tokens_list=[output_tokens_2],
        output_logprobs_list=[[-0.1] * len(output_tokens_2)],
        finish_reasons=["stop"],
    )
    result_2 = await managed.chat_completion(
        messages=messages_2, tools=tools, tool_choice="auto"
    )
    assert result_2.choices[0].message.content == output_2
    # Still 1 node — step 2 extended step 1
    assert len(managed.get_state()["nodes"]) == 1
@pytest.mark.asyncio
async def test_tool_call_multiple_tools_parsed(mock_server_with_tools):
    """Multiple tool calls in one response are all parsed."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        {"type": "function", "function": {"name": "get_time", "parameters": {}}},
    ]
    messages = [{"role": "user", "content": "Weather and time?"}]
    # Two back-to-back <tool_call> blocks in one model output.
    raw_output = (
        '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>\n'
        '<tool_call>{"name": "get_time", "arguments": {"tz": "PST"}}</tool_call>'
    )
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )
    result = await managed.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )
    assert result.choices[0].finish_reason == "tool_calls"
    assert len(result.choices[0].message.tool_calls) == 2
    names = {tc["function"]["name"] for tc in result.choices[0].message.tool_calls}
    assert names == {"get_weather", "get_time"}
@pytest.mark.asyncio
async def test_tool_call_node_masking(mock_server_with_tools):
    """Nodes have proper masking even with tool parsing active."""
    managed = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )
    await managed.chat_completion(messages=messages, tools=tools)
    node = managed.get_state()["nodes"][0]
    # Lengths must match
    assert len(node.tokens) == len(node.masked_tokens) == len(node.logprobs)
    # Should have masked prompt tokens and actual completion tokens
    num_masked = sum(1 for t in node.masked_tokens if t == -100)
    num_actual = sum(1 for t in node.masked_tokens if t != -100)
    assert num_masked > 0
    assert num_actual > 0
    # Prompt logprobs = 1.0, completion logprobs < 0
    assert all(lp == 1.0 for lp in node.logprobs[:num_masked])
    assert all(lp < 0 for lp in node.logprobs[num_masked:])
if __name__ == "__main__":
    # Allow running this test module directly with verbose output.
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,852 @@
"""Mock-based tests for the ManagedServer OpenAI proxy.
Uses ServerHarness as the backend — no real model or GPU needed.
Tests the full HTTP layer: session management, chat completions,
tool call translation, render endpoint, nodes, cleanup.
"""
import json
import pytest
from fastapi.testclient import TestClient
from atroposlib.envs.server_handling.managed_server_proxy import create_app
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.server_manager import ServerManager
# ---------------------------------------------------------------------------
# Mock tokenizer (same as test_managed_server.py / test_tool_call_translator.py)
# ---------------------------------------------------------------------------
class MockTokenizer:
    """Character-level tokenizer stand-in: one token per character ordinal."""

    def __init__(self):
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        """Map each char to its ordinal, optionally prefixed with BOS."""
        body = [ord(ch) for ch in text]
        return [self.bos_token_id, *body] if add_special_tokens else body

    def decode(self, tokens, skip_special_tokens=False):
        """Inverse of encode; control-range tokens (<= 31) render as ''."""
        specials = {self.bos_token_id, self.eos_token_id}
        chars = []
        for tok in tokens:
            if skip_special_tokens and tok in specials:
                continue
            chars.append(chr(tok) if tok > 31 else "")
        return "".join(chars)

    def get_vocab(self):
        """ASCII-only vocabulary: char -> ordinal."""
        return {chr(code): code for code in range(128)}

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        """Render messages as "<role>content</role>" with optional tools preamble."""
        pieces = []
        if tools:
            pieces.append(f"<tools>{json.dumps(tools)}</tools>\n")
        for message in messages:
            role = message["role"]
            body = message.get("content", "") or ""
            pieces.append(f"<{role}>{body}</{role}>")
        if add_generation_prompt:
            pieces.append("<assistant>")
        rendered = "".join(pieces)
        return self.encode(rendered) if tokenize else rendered
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_backend():
    """Create a mock server backend with tokenizer."""
    harness = ServerHarness()
    harness.tokenizer = MockTokenizer()
    # ServerManager's _select_server checks these attributes
    harness.server_healthy = True
    harness.config = type("Config", (), {"model_name": "test_model"})()
    return harness
@pytest.fixture
def server_manager(mock_backend):
    """Create a ServerManager wrapping the mock backend."""
    # ServerManager's constructor rejects empty configs, so assemble the
    # instance by hand, bypassing __init__.
    manager = object.__new__(ServerManager)
    for attr, value in (
        ("max_n_completions", 8),
        ("reasoning_config", None),
        ("servers", [mock_backend]),
    ):
        setattr(manager, attr, value)
    return manager
@pytest.fixture
def client(server_manager):
    """Create a test client for the proxy app."""
    proxy_app = create_app(
        server_manager=server_manager,
        tokenizer=MockTokenizer(),
        model_name="test_model",
    )
    return TestClient(proxy_app)
@pytest.fixture
def client_and_backend(mock_backend, server_manager):
    """Return both client and backend for tests that need to set up mock responses."""
    tokenizer = MockTokenizer()
    proxy_app = create_app(
        server_manager=server_manager,
        tokenizer=tokenizer,
        model_name="test_model",
    )
    # Hand back the same tokenizer so tests can reproduce the exact prompt text.
    return TestClient(proxy_app), mock_backend, tokenizer
def _setup_completion(
backend, tokenizer, prompt_text, output_texts, finish_reasons=None
):
"""Helper to set up a mock tokens_and_logprobs response."""
prompt_tokens = tokenizer.encode(prompt_text)
output_tokens_list = [[ord(c) for c in text] for text in output_texts]
output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
if finish_reasons is None:
finish_reasons = ["stop"] * len(output_texts)
backend.set_tokens_and_logprobs_response(
prompt=prompt_text,
prompt_tokens=prompt_tokens,
output_tokens_list=output_tokens_list,
output_logprobs_list=output_logprobs_list,
finish_reasons=finish_reasons,
)
# ---------------------------------------------------------------------------
# Health / Models
# ---------------------------------------------------------------------------
class TestHealth:
    """Smoke tests for the service-level endpoints (no session required)."""

    def test_health(self, client):
        """/health reports status, configured model name, and live session count."""
        resp = client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "ok"
        assert data["model"] == "test_model"
        assert data["sessions"] == 0

    def test_models(self, client):
        """/v1/models lists exactly the one configured model, OpenAI-style."""
        resp = client.get("/v1/models")
        assert resp.status_code == 200
        data = resp.json()
        assert data["object"] == "list"
        assert len(data["data"]) == 1
        assert data["data"][0]["id"] == "test_model"
# ---------------------------------------------------------------------------
# Session Management
# ---------------------------------------------------------------------------
class TestSessionManagement:
    """Session lifecycle: create, list, delete, and 404 handling."""

    def test_create_session(self, client):
        """Creating with an empty body yields a uuid and default settings."""
        resp = client.post("/sessions/create", json={})
        assert resp.status_code == 200
        data = resp.json()
        assert "uuid" in data
        assert data["model_name"] == "test_model"
        # "hermes" is the default tool parser when none is requested.
        assert data["tool_parser"] == "hermes"

    def test_create_session_custom_parser(self, client):
        """An explicitly requested tool parser is echoed back."""
        resp = client.post("/sessions/create", json={"tool_parser": "hermes"})
        assert resp.status_code == 200
        assert resp.json()["tool_parser"] == "hermes"

    def test_list_sessions(self, client):
        """/sessions returns every live session's uuid."""
        # Create 3 sessions
        uuids = []
        for _ in range(3):
            resp = client.post("/sessions/create", json={})
            uuids.append(resp.json()["uuid"])
        resp = client.get("/sessions")
        assert resp.status_code == 200
        sessions = resp.json()["sessions"]
        assert len(sessions) == 3
        listed_uuids = {s["uuid"] for s in sessions}
        assert listed_uuids == set(uuids)

    def test_delete_session(self, client):
        """DELETE /{uuid} removes the session; later access 404s."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        resp = client.delete(f"/{uuid}")
        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"
        # Should be gone
        resp = client.get(f"/{uuid}/nodes")
        assert resp.status_code == 404

    def test_delete_nonexistent_session(self, client):
        """Deleting an unknown uuid is a 404, not an error."""
        resp = client.delete("/nonexistent-uuid")
        assert resp.status_code == 404

    def test_session_not_found(self, client):
        """Chat completions against an unknown session uuid 404."""
        resp = client.post(
            "/nonexistent-uuid/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}]},
        )
        assert resp.status_code == 404
# ---------------------------------------------------------------------------
# Chat Completions
# ---------------------------------------------------------------------------
class TestChatCompletions:
    """OpenAI-compatible /v1/chat/completions behavior through the proxy."""

    def test_basic_completion(self, client_and_backend):
        """A single-turn request yields a well-formed chat.completion response."""
        client, backend, tokenizer = client_and_backend
        # Create session
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        # Set up mock response
        messages = [{"role": "user", "content": "Hello"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["Hi there!"])
        # Make request
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "max_tokens": 100},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["object"] == "chat.completion"
        assert data["model"] == "test_model"
        assert len(data["choices"]) == 1
        assert data["choices"][0]["message"]["role"] == "assistant"
        assert data["choices"][0]["message"]["content"] == "Hi there!"
        assert data["choices"][0]["finish_reason"] == "stop"
        # OpenAI-style completion id prefix.
        assert data["id"].startswith("chatcmpl-")

    def test_completion_with_n(self, client_and_backend):
        """n=3 returns three choices."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        messages = [{"role": "user", "content": "Pick a number"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["One", "Two", "Three"])
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "n": 3, "max_tokens": 50},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["choices"]) == 3

    def test_empty_messages_error(self, client):
        """An empty messages list is rejected with 400."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": []},
        )
        assert resp.status_code == 400

    def test_completion_with_system_prompt(self, client_and_backend):
        """System + user messages round-trip through the chat template."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages},
        )
        assert resp.status_code == 200
        assert resp.json()["choices"][0]["message"]["content"] == "Hello!"
# ---------------------------------------------------------------------------
# Tool Call Handling
# ---------------------------------------------------------------------------
class TestToolCalls:
    """Tool-call parsing and tool_choice handling over the HTTP layer."""

    def test_tool_call_outbound(self, client_and_backend):
        """Model generates <tool_call> tags → response has structured tool_calls."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
        messages = [{"role": "user", "content": "Search cats"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, tools=tools
        )
        # Model output includes tool call tags
        raw_output = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _setup_completion(backend, tokenizer, prompt_text, [raw_output])
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "tools": tools},
        )
        assert resp.status_code == 200
        data = resp.json()
        choice = data["choices"][0]
        assert choice["finish_reason"] == "tool_calls"
        assert "tool_calls" in choice["message"]
        assert len(choice["message"]["tool_calls"]) == 1
        tc = choice["message"]["tool_calls"][0]
        assert tc["function"]["name"] == "search"
        # Arguments are serialized as a JSON string, per the OpenAI schema.
        assert json.loads(tc["function"]["arguments"]) == {"query": "cats"}

    def test_tool_choice_none(self, client_and_backend):
        """tool_choice=none → no parsing, raw text returned."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
        messages = [{"role": "user", "content": "Search cats"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, tools=tools
        )
        raw_output = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _setup_completion(backend, tokenizer, prompt_text, [raw_output])
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "tools": tools, "tool_choice": "none"},
        )
        assert resp.status_code == 200
        choice = resp.json()["choices"][0]
        # Should NOT have tool_calls since tool_choice is "none"
        assert choice["finish_reason"] == "stop"
        assert (
            "tool_calls" not in choice["message"]
            or choice["message"].get("tool_calls") is None
        )

    def test_nodes_preserve_raw_text(self, client_and_backend):
        """ManagedServer nodes should have raw text, not parsed tool_calls."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
        messages = [{"role": "user", "content": "Search cats"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, tools=tools
        )
        raw_output = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _setup_completion(backend, tokenizer, prompt_text, [raw_output])
        client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages, "tools": tools},
        )
        # Check nodes — should have the raw tokens, not parsed
        resp = client.get(f"/{uuid}/nodes")
        assert resp.status_code == 200
        nodes = resp.json()["nodes"]
        assert len(nodes) == 1
# ---------------------------------------------------------------------------
# Render Endpoint
# ---------------------------------------------------------------------------
class TestRender:
    """The render endpoint: prompt preview without any generation."""

    def test_render_basic(self, client):
        """Render returns prompt text plus its token ids and count."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        resp = client.post(
            f"/{uuid}/v1/chat/completions/render",
            json={"messages": [{"role": "user", "content": "Hello"}]},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "prompt_text" in data
        assert "token_ids" in data
        assert "num_tokens" in data
        assert data["num_tokens"] == len(data["token_ids"])
        # MockTokenizer renders messages as "<role>content</role>".
        assert "<user>Hello</user>" in data["prompt_text"]

    def test_render_with_tools(self, client):
        """Tool definitions appear in the rendered prompt."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        tools = [{"type": "function", "function": {"name": "search"}}]
        resp = client.post(
            f"/{uuid}/v1/chat/completions/render",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "tools": tools,
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        # Tool definitions should appear in the rendered prompt
        assert "search" in data["prompt_text"]

    def test_render_does_not_create_nodes(self, client):
        """Render should not cause any generation or node creation."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        client.post(
            f"/{uuid}/v1/chat/completions/render",
            json={"messages": [{"role": "user", "content": "Hi"}]},
        )
        resp = client.get(f"/{uuid}/nodes")
        assert resp.json()["nodes"] == []
# ---------------------------------------------------------------------------
# Nodes
# ---------------------------------------------------------------------------
class TestNodes:
    """The /nodes endpoint: token/mask/logprob bookkeeping per session."""

    def test_get_nodes_empty(self, client):
        """A fresh session has no nodes."""
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        resp = client.get(f"/{uuid}/nodes")
        assert resp.status_code == 200
        assert resp.json()["nodes"] == []

    def test_get_nodes_after_completion(self, client_and_backend):
        """One completion produces one node with aligned token/mask/logprob arrays."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        messages = [{"role": "user", "content": "Hi"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
        client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages},
        )
        resp = client.get(f"/{uuid}/nodes")
        assert resp.status_code == 200
        nodes = resp.json()["nodes"]
        assert len(nodes) == 1
        node = nodes[0]
        assert "tokens" in node
        assert "masked_tokens" in node
        assert "logprobs" in node
        assert "full_text" in node
        assert (
            len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"])
        )

    def test_nodes_have_proper_masking(self, client_and_backend):
        """Prompt positions are masked (-100, logprob 1.0); completion logprobs < 0."""
        client, backend, tokenizer = client_and_backend
        resp = client.post("/sessions/create", json={})
        uuid = resp.json()["uuid"]
        messages = [{"role": "user", "content": "Hi"}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        prompt_tokens = tokenizer.encode(prompt_text)
        prompt_len = len(prompt_tokens)
        _setup_completion(backend, tokenizer, prompt_text, ["Hello!"])
        client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": messages},
        )
        resp = client.get(f"/{uuid}/nodes")
        node = resp.json()["nodes"][0]
        # Prompt tokens should be masked with -100
        assert all(t == -100 for t in node["masked_tokens"][:prompt_len])
        # Prompt logprobs should be 1.0
        assert all(lp == 1.0 for lp in node["logprobs"][:prompt_len])
        # Completion logprobs should be actual values (negative)
        assert all(lp < 0 for lp in node["logprobs"][prompt_len:])
# ---------------------------------------------------------------------------
# Deep multi-step node handling
# ---------------------------------------------------------------------------
class TestMultiStepNodeHandling:
"""Test that multi-step conversations with tool calls produce exactly 1 node.
Simulates a realistic 10+ message agentic conversation:
user → assistant(tool_call) → tool_result → assistant(text) →
user → assistant(tool_call) → tool_result → assistant(tool_call) →
tool_result → assistant(text) → user → assistant(text)
Each step extends the previous node, so we should end up with exactly
1 node containing the full tokenized conversation.
"""
def _do_step(
    self,
    client,
    backend,
    tokenizer,
    uuid,
    messages,
    output_text,
    tools=None,
    # NOTE(review): expect_tool_calls is currently unused — confirm intent.
    expect_tool_calls=False,
):
    """Helper: use render endpoint to get exact prompt, set up mock, call endpoint.

    Returns the parsed JSON body of the chat-completion response.
    """
    body = {"messages": messages, "max_tokens": 200}
    if tools:
        body["tools"] = tools
    # Use the render endpoint to get the exact prompt the proxy will generate
    # (this includes tool_call reconstruction through the translator)
    render_resp = client.post(f"/{uuid}/v1/chat/completions/render", json=body)
    assert render_resp.status_code == 200, f"Render failed: {render_resp.json()}"
    prompt_text = render_resp.json()["prompt_text"]
    _setup_completion(backend, tokenizer, prompt_text, [output_text])
    resp = client.post(f"/{uuid}/v1/chat/completions", json=body)
    assert resp.status_code == 200, f"Step failed: {resp.json()}"
    return resp.json()
def test_10_message_conversation_one_node(self, client_and_backend):
    """Full 10-message conversation with tool calls → exactly 1 node."""
    client, backend, tokenizer = client_and_backend
    resp = client.post("/sessions/create", json={})
    uuid = resp.json()["uuid"]
    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        {
            "type": "function",
            "function": {"name": "get_forecast", "parameters": {}},
        },
    ]
    # -- Step 1: user asks about weather --
    messages = [{"role": "user", "content": "What's the weather in SF?"}]
    output_1 = '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>'
    data = self._do_step(
        client, backend, tokenizer, uuid, messages, output_1, tools=tools
    )
    assert data["choices"][0]["finish_reason"] == "tool_calls"
    tc_1 = data["choices"][0]["message"]["tool_calls"]
    # Check: 1 node so far
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after step 1, got {len(nodes)}"
    # -- Step 2: tool result --
    messages = [
        {"role": "user", "content": "What's the weather in SF?"},
        {"role": "assistant", "content": None, "tool_calls": tc_1},
        {
            "role": "tool",
            "tool_call_id": tc_1[0]["id"],
            "content": "72°F and sunny",
        },
    ]
    output_2 = "The weather in SF is 72°F and sunny! Want the forecast too?"
    self._do_step(client, backend, tokenizer, uuid, messages, output_2, tools=tools)
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after step 2, got {len(nodes)}"
    # -- Step 3: user says yes --
    messages.extend(
        [
            {"role": "assistant", "content": output_2},
            {"role": "user", "content": "Yes please, get the forecast"},
        ]
    )
    output_3 = '<tool_call>{"name": "get_forecast", "arguments": {"city": "SF"}}</tool_call>'
    data = self._do_step(
        client, backend, tokenizer, uuid, messages, output_3, tools=tools
    )
    tc_3 = data["choices"][0]["message"]["tool_calls"]
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after step 3, got {len(nodes)}"
    # -- Step 4: forecast tool result --
    messages.extend(
        [
            {"role": "assistant", "content": None, "tool_calls": tc_3},
            {
                "role": "tool",
                "tool_call_id": tc_3[0]["id"],
                "content": "Rain expected tomorrow",
            },
        ]
    )
    output_4 = "The forecast says rain is expected tomorrow in SF."
    self._do_step(client, backend, tokenizer, uuid, messages, output_4, tools=tools)
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after step 4, got {len(nodes)}"
    # -- Step 5: user asks about another city --
    messages.extend(
        [
            {"role": "assistant", "content": output_4},
            {"role": "user", "content": "What about NYC?"},
        ]
    )
    output_5 = '<tool_call>{"name": "get_weather", "arguments": {"city": "NYC"}}</tool_call>'
    data = self._do_step(
        client, backend, tokenizer, uuid, messages, output_5, tools=tools
    )
    tc_5 = data["choices"][0]["message"]["tool_calls"]
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after step 5, got {len(nodes)}"
    # -- Step 6: NYC tool result --
    messages.extend(
        [
            {"role": "assistant", "content": None, "tool_calls": tc_5},
            {
                "role": "tool",
                "tool_call_id": tc_5[0]["id"],
                "content": "55°F and cloudy",
            },
        ]
    )
    output_6 = "NYC is 55°F and cloudy. Quite different from SF!"
    self._do_step(client, backend, tokenizer, uuid, messages, output_6, tools=tools)
    # -- FINAL CHECK: still exactly 1 node after 6 completions / 12+ messages --
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert (
        len(nodes) == 1
    ), f"Expected 1 node after full conversation, got {len(nodes)}"
    # Verify the node has proper structure
    node = nodes[0]
    assert (
        len(node["tokens"]) == len(node["masked_tokens"]) == len(node["logprobs"])
    )
    assert len(node["tokens"]) > 0
    # Verify masking: there should be SOME -100 (prompt) and SOME actual tokens
    num_masked = sum(1 for t in node["masked_tokens"] if t == -100)
    num_actual = sum(1 for t in node["masked_tokens"] if t != -100)
    assert num_masked > 0, "Should have masked prompt tokens"
    assert num_actual > 0, "Should have unmasked completion tokens"
def test_plain_multi_turn_no_tools_one_node(self, client_and_backend):
    """5-turn conversation without tools → exactly 1 node."""
    client, backend, tokenizer = client_and_backend
    resp = client.post("/sessions/create", json={})
    uuid = resp.json()["uuid"]
    conversation = []
    for i in range(5):
        # Add user message
        conversation.append({"role": "user", "content": f"Turn {i+1} question"})
        prompt_text = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        output = f"Response to turn {i+1}"
        _setup_completion(backend, tokenizer, prompt_text, [output])
        resp = client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": conversation},
        )
        assert resp.status_code == 200
        # Add assistant response for next turn
        conversation.append({"role": "assistant", "content": output})
    # After 5 turns (10 messages), should still be 1 node
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node after 5 turns, got {len(nodes)}"
def test_tool_then_plain_then_tool_one_node(self, client_and_backend):
    """Mixed: tool call → plain text → tool call → plain → exactly 1 node."""
    client, backend, tokenizer = client_and_backend
    uuid = client.post("/sessions/create", json={}).json()["uuid"]
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]

    # Step 1: tool call
    convo = [{"role": "user", "content": "Search for cats"}]
    raw = '<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
    step = self._do_step(client, backend, tokenizer, uuid, convo, raw, tools=tools)
    cat_calls = step["choices"][0]["message"]["tool_calls"]

    # Step 2: tool result → plain response
    convo += [
        {"role": "assistant", "content": None, "tool_calls": cat_calls},
        {
            "role": "tool",
            "tool_call_id": cat_calls[0]["id"],
            "content": "Found 10 cats",
        },
    ]
    self._do_step(
        client, backend, tokenizer, uuid, convo, "Here are 10 cats!", tools=tools
    )

    # Step 3: user asks for more → another tool call
    convo += [
        {"role": "assistant", "content": "Here are 10 cats!"},
        {"role": "user", "content": "Search for dogs too"},
    ]
    raw = '<tool_call>{"name": "search", "arguments": {"q": "dogs"}}</tool_call>'
    step = self._do_step(client, backend, tokenizer, uuid, convo, raw, tools=tools)
    dog_calls = step["choices"][0]["message"]["tool_calls"]

    # Step 4: tool result → plain response
    convo += [
        {"role": "assistant", "content": None, "tool_calls": dog_calls},
        {
            "role": "tool",
            "tool_call_id": dog_calls[0]["id"],
            "content": "Found 5 dogs",
        },
    ]
    self._do_step(
        client, backend, tokenizer, uuid, convo, "Found 5 dogs too!", tools=tools
    )

    # Step 5: plain follow-up; model answers without calling a tool
    convo += [
        {"role": "assistant", "content": "Found 5 dogs too!"},
        {"role": "user", "content": "Thanks!"},
    ]
    self._do_step(
        client, backend, tokenizer, uuid, convo, "You're welcome!", tools=tools
    )

    # 5 completion steps, 11 messages — still 1 node
    nodes = client.get(f"/{uuid}/nodes").json()["nodes"]
    assert len(nodes) == 1, f"Expected 1 node, got {len(nodes)}"
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
class TestCleanup:
    """Session deletion wipes all stored state for that UUID."""

    def test_delete_resets_nodes(self, client_and_backend):
        client, backend, tokenizer = client_and_backend
        uuid = client.post("/sessions/create", json={}).json()["uuid"]
        msgs = [{"role": "user", "content": "Hi"}]
        rendered = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        _setup_completion(backend, tokenizer, rendered, ["Hello!"])
        client.post(
            f"/{uuid}/v1/chat/completions",
            json={"messages": msgs},
        )
        # Delete the session, then verify it is gone.
        assert client.delete(f"/{uuid}").status_code == 200
        assert client.get(f"/{uuid}/nodes").status_code == 404
# ---------------------------------------------------------------------------
# Error format
# ---------------------------------------------------------------------------
class TestErrorFormat:
    """Errors are returned in the OpenAI error envelope."""

    def test_404_is_openai_format(self, client):
        resp = client.get("/nonexistent-uuid/nodes")
        assert resp.status_code == 404
        data = resp.json()
        assert "error" in data
        # OpenAI clients expect message/type/code inside the error object.
        for field in ("message", "type", "code"):
            assert field in data["error"]

View file

@ -0,0 +1,363 @@
"""Integration tests for ManagedServer OpenAI proxy against real vLLM backend.
Spins up example_trainer/vllm_api_server.py with Qwen3-4B as a subprocess.
Requires GPU skipped by default. Run with:
pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s
"""
import os
import signal
import subprocess
import sys
import tempfile
import time

import pytest
import requests
from transformers import AutoTokenizer

from atroposlib.envs.server_handling.managed_server_proxy import create_app
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Port and model for the real vLLM backend under test.
VLLM_PORT = 8123
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
# Model name the proxy advertises (same as the backend model here).
PROXY_MODEL = VLLM_MODEL
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
# Repo root relative to this test file; used to locate the server script.
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
VENV_PYTHON = sys.executable  # use the current interpreter
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def vllm_backend():
    """Start the vLLM api server as a subprocess. Module-scoped so it's shared.

    Server output is captured in an unnamed temp file instead of a PIPE:
    vLLM logs verbosely during model load, and an undrained pipe buffer
    (~64 KiB) fills up and blocks the child before it can ever report
    healthy. The log tail is still surfaced on failure for debugging.
    """
    log = tempfile.TemporaryFile()

    def log_tail(limit: int = 3000) -> str:
        # Read the last `limit` bytes of captured server output.
        log.seek(0, os.SEEK_END)
        log.seek(max(0, log.tell() - limit))
        return log.read().decode(errors="replace")

    cmd = [
        VENV_PYTHON,
        VLLM_SCRIPT,
        "--model",
        VLLM_MODEL,
        "--port",
        str(VLLM_PORT),
        "--max-model-len",
        "32000",
        "--max-num-seqs",
        "32",
    ]
    proc = subprocess.Popen(
        cmd,
        stdout=log,
        stderr=subprocess.STDOUT,
        cwd=REPO_ROOT,
    )
    # Poll /health until the model finishes loading (up to 3 minutes).
    deadline = time.time() + 180
    healthy = False
    while time.time() < deadline:
        try:
            resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
            if resp.status_code == 200:
                healthy = True
                break
        except (requests.ConnectionError, requests.Timeout):
            pass
        if proc.poll() is not None:
            pytest.fail(
                f"vLLM server exited early (code={proc.returncode}):\n{log_tail()}"
            )
        time.sleep(3)
    if not healthy:
        proc.kill()
        pytest.fail(f"vLLM server didn't become healthy within 180s:\n{log_tail()}")
    yield proc
    # Graceful shutdown first; escalate to SIGKILL if it hangs.
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=15)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
    log.close()
@pytest.fixture(scope="module")
def tokenizer_instance():
    """Module-scoped HF tokenizer matching the vLLM backend model."""
    tok = AutoTokenizer.from_pretrained(VLLM_MODEL)
    return tok
@pytest.fixture(scope="module")
def proxy_client(vllm_backend, tokenizer_instance):
    """TestClient for the proxy app, wired to the live vLLM backend."""
    from fastapi.testclient import TestClient

    backend_cfg = APIServerConfig(
        model_name=VLLM_MODEL,
        base_url=VLLM_BASE_URL,
        api_key="",
        server_type="vllm",
        health_check=False,
    )
    manager = ServerManager(configs=[backend_cfg])
    proxy_app = create_app(
        server_manager=manager,
        tokenizer=tokenizer_instance,
        model_name=VLLM_MODEL,
    )
    return TestClient(proxy_app)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.gpu
class TestRealChatCompletion:
    """Basic chat-completion behavior against the live backend."""

    def test_basic_completion(self, proxy_client):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello in one word."}],
                "max_tokens": 30,
                "temperature": 0.0,
            },
        )
        assert resp.status_code == 200
        body = resp.json()
        assert len(body["choices"]) == 1
        text = body["choices"][0]["message"]["content"]
        assert text is not None
        assert len(text) > 0
        assert body["model"] == VLLM_MODEL

    def test_n_completions(self, proxy_client):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Pick a random number"}],
                "max_tokens": 20,
                "temperature": 1.0,
                "n": 4,
            },
        )
        assert resp.status_code == 200
        assert len(resp.json()["choices"]) == 4
        # n=4 samples should produce 4 stored nodes.
        nodes = proxy_client.get(f"/{uuid}/nodes").json()["nodes"]
        assert len(nodes) == 4
@pytest.mark.gpu
class TestRealLogprobs:
    """Stored logprobs distinguish prompt (sentinel 1.0) from completion (<0)."""

    def test_logprobs_are_valid(self, proxy_client):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )
        nodes = proxy_client.get(f"/{uuid}/nodes").json()["nodes"]
        assert len(nodes) == 1
        lps = nodes[0]["logprobs"]
        # First index whose logprob differs from the 1.0 prompt sentinel
        # marks where the completion starts.
        boundary = next((i for i, lp in enumerate(lps) if lp != 1.0), 0)
        assert all(lp == 1.0 for lp in lps[:boundary])
        completion_lps = [lp for lp in lps[boundary:] if lp != 1.0]
        assert len(completion_lps) > 0
        assert all(lp < 0 for lp in completion_lps)
@pytest.mark.gpu
class TestRealTokenAlignment:
    """tokens / masked_tokens / logprobs stay length-aligned and decodable."""

    def test_tokens_decode_to_full_text(self, proxy_client, tokenizer_instance):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say exactly: test123"}],
                "max_tokens": 30,
                "temperature": 0.0,
            },
        )
        node = proxy_client.get(f"/{uuid}/nodes").json()["nodes"][0]
        # The three parallel arrays must line up token-for-token.
        assert len(node["tokens"]) == len(node["masked_tokens"])
        assert len(node["tokens"]) == len(node["logprobs"])
        # Decoding should produce text; exact equality with full_text can
        # differ because of special-token handling, but content should match.
        decoded = tokenizer_instance.decode(node["tokens"])
        assert len(decoded) > 0
@pytest.mark.gpu
class TestRealRender:
    """/render output matches tokenizer.apply_chat_template exactly."""

    def test_render_matches_tokenizer(self, proxy_client, tokenizer_instance):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        convo = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello!"},
        ]
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions/render",
            json={"messages": convo},
        )
        assert resp.status_code == 200
        # The proxy's rendering must equal a direct template application.
        expected = tokenizer_instance.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
        assert resp.json()["prompt_text"] == expected
@pytest.mark.gpu
class TestRealSequenceExtension:
    """A second turn that prefixes the first should extend the session."""

    def test_multi_turn_extends(self, proxy_client):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        # Turn 1
        first = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello"}],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )
        assert first.status_code == 200
        reply = first.json()["choices"][0]["message"]["content"]
        # Turn 2 — extends turn 1 by replaying it as history
        second = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [
                    {"role": "user", "content": "Say hello"},
                    {"role": "assistant", "content": reply},
                    {"role": "user", "content": "Now say goodbye"},
                ],
                "max_tokens": 20,
                "temperature": 0.0,
            },
        )
        assert second.status_code == 200
        # Node count depends on prefix matching; at minimum one node exists.
        nodes = proxy_client.get(f"/{uuid}/nodes").json()["nodes"]
        assert len(nodes) >= 1
@pytest.mark.gpu
class TestRealConcurrentSessions:
    """Multiple sessions keep independent node stores."""

    def test_sessions_independent(self, proxy_client):
        session_ids = [
            proxy_client.post("/sessions/create", json={}).json()["uuid"]
            for _ in range(3)
        ]
        # Run one completion per session.
        for idx, sid in enumerate(session_ids):
            resp = proxy_client.post(
                f"/{sid}/v1/chat/completions",
                json={
                    "messages": [{"role": "user", "content": f"Count to {idx+1}"}],
                    "max_tokens": 30,
                    "temperature": 0.0,
                },
            )
            assert resp.status_code == 200
        # Each session must hold exactly its own single node.
        for sid in session_ids:
            nodes = proxy_client.get(f"/{sid}/nodes").json()["nodes"]
            assert len(nodes) == 1
@pytest.mark.gpu
class TestRealOpenAIClientCompat:
    """Response schema carries every field the openai Python client requires."""

    def test_openai_client_works(self, proxy_client):
        uuid = proxy_client.post("/sessions/create", json={}).json()["uuid"]
        # TestClient exposes no real port, so instead of pointing a real
        # openai.Client at the proxy we validate the JSON structure directly.
        resp = proxy_client.post(
            f"/{uuid}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 10,
                "temperature": 0.0,
            },
        )
        body = resp.json()
        # Top-level fields the openai client expects.
        for field in ("id", "object", "created", "model", "choices"):
            assert field in body
        assert body["object"] == "chat.completion"
        assert isinstance(body["choices"], list)
        # Per-choice fields.
        for choice in body["choices"]:
            assert "index" in choice
            assert "message" in choice
            assert "finish_reason" in choice
            assert "role" in choice["message"]
            assert "content" in choice["message"]

View file

@ -0,0 +1,478 @@
"""Unit tests for ToolCallTranslator — vLLM parser wrapper and lookup table.
These are pure logic tests, no server or model needed. Uses a mock tokenizer.
"""
import json
import pytest
from atroposlib.envs.server_handling.tool_call_translator import ToolCallTranslator
# ---------------------------------------------------------------------------
# Mock tokenizer (same one from test_managed_server.py)
# ---------------------------------------------------------------------------
class MockTokenizer:
    """Character-level fake tokenizer: each character maps to its ordinal.

    Token ids 1 and 2 play BOS/EOS; ids <= 31 decode to the empty string.
    """

    def __init__(self):
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        """Map each char to ord(char), optionally prefixing BOS."""
        ids = [ord(ch) for ch in text]
        return [self.bos_token_id, *ids] if add_special_tokens else ids

    def decode(self, tokens, skip_special_tokens=False):
        """Inverse of encode; non-printable ids (<=31) render as ''."""
        specials = {self.bos_token_id, self.eos_token_id}
        pieces = []
        for tok in tokens:
            if skip_special_tokens and tok in specials:
                continue
            pieces.append(chr(tok) if tok > 31 else "")
        return "".join(pieces)

    def get_vocab(self):
        # Minimal vocab for the parser — hermes parser calls this
        return dict(zip(map(chr, range(128)), range(128)))

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        """Render messages as <role>content</role> runs, with optional tools header."""
        parts = []
        if tools:
            parts.append(f"<tools>{json.dumps(tools)}</tools>\n")
        for message in messages:
            role = message["role"]
            parts.append(f"<{role}>{message.get('content', '')}</{role}>")
        if add_generation_prompt:
            parts.append("<assistant>")
        rendered = "".join(parts)
        return self.encode(rendered) if tokenize else rendered
@pytest.fixture
def translator():
    """Hermes-parser translator wrapped around the mock tokenizer."""
    return ToolCallTranslator(tokenizer=MockTokenizer(), parser_name="hermes")
# ---------------------------------------------------------------------------
# Outbound: model output → OpenAI format
# ---------------------------------------------------------------------------
class TestParseModelOutput:
    """Outbound direction: raw model text → (content, tool_calls, finish_reason).

    Exercises ToolCallTranslator.parse_model_output over the hermes-style
    ``<tool_call>...</tool_call>`` grammar: single/multiple calls, plain text,
    mixed content, tool_choice gating, and malformed-input fallbacks.
    """

    def test_single_tool_call(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "search"
        assert json.loads(tool_calls[0].function.arguments) == {"query": "cats"}
        # content is None or empty when full text is a tool call
        assert content is None or content.strip() == ""

    def test_multiple_tool_calls(self, translator):
        # Two adjacent tool_call tags in one completion.
        raw = (
            '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>\n'
            '<tool_call>{"name": "get_time", "arguments": {"tz": "PST"}}</tool_call>'
        )
        tools = [
            {"type": "function", "function": {"name": "get_weather"}},
            {"type": "function", "function": {"name": "get_time"}},
        ]
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        assert finish_reason == "tool_calls"
        assert len(tool_calls) == 2
        names = {tc.function.name for tc in tool_calls}
        assert names == {"get_weather", "get_time"}

    def test_no_tool_calls(self, translator):
        # Plain text output: passed through untouched with finish_reason "stop".
        raw = "The weather in SF is 72 degrees."
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw

    def test_content_before_tool_call(self, translator):
        raw = 'Let me search for that.\n<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1
        # Content before the tool call tag should be preserved
        assert content is not None
        assert "search for that" in content

    def test_tool_choice_none_skips_parsing(self, translator):
        # tool_choice="none" disables parsing even when tags are present.
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="none",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw  # Raw text returned as-is

    def test_no_tools_skips_parsing(self, translator):
        # No tools declared → no parsing attempted.
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw, tool_choice="auto", tools=None
        )
        assert finish_reason == "stop"
        assert tool_calls is None
        assert content == raw

    def test_malformed_json_graceful_fallback(self, translator):
        raw = "<tool_call>not valid json at all</tool_call>"
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        # Parser should handle gracefully — either no tools or raw content
        assert finish_reason == "stop"
        assert tool_calls is None

    def test_unclosed_tool_call(self, translator):
        raw = '<tool_call>{"name": "search", "arguments": {"query": "cats"}}'
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        # The hermes regex has a branch for unclosed tags
        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        assert len(tool_calls) == 1

    def test_nested_json_arguments(self, translator):
        # Deeply nested arguments must round-trip through the parser intact.
        args = {
            "filter": {
                "type": "date",
                "range": {"start": "2024-01-01", "end": "2024-12-31"},
            }
        }
        raw = f'<tool_call>{{"name": "search", "arguments": {json.dumps(args)}}}</tool_call>'
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "tool_calls"
        assert json.loads(tool_calls[0].function.arguments) == args

    def test_tool_call_with_think_tags(self, translator):
        raw = (
            "<think>I should search for this information.</think>\n"
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        content, tool_calls, finish_reason = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        assert finish_reason == "tool_calls"
        assert tool_calls is not None
        # Think content should be in the content field
        if content:
            assert "think" in content or "search for this" in content
# ---------------------------------------------------------------------------
# Lookup table
# ---------------------------------------------------------------------------
class TestLookupTable:
    """parse_model_output records the raw text for every generated call id."""

    def test_parse_populates_lookup(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        _, calls, _ = translator.parse_model_output(
            raw,
            tool_choice="auto",
            tools=[{"type": "function", "function": {"name": "search"}}],
        )
        table = translator.call_id_to_raw_text
        assert len(table) == 1
        assert calls[0].id in table
        assert table[calls[0].id] == raw

    def test_lookup_accumulates(self, translator):
        tools = [{"type": "function", "function": {"name": "search"}}]
        first = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        second = (
            '<tool_call>{"name": "search", "arguments": {"query": "dogs"}}</tool_call>'
        )
        _, calls_a, _ = translator.parse_model_output(
            first, tool_choice="auto", tools=tools
        )
        _, calls_b, _ = translator.parse_model_output(
            second, tool_choice="auto", tools=tools
        )
        # Both parses keep their entries — the table is never reset.
        assert len(translator.call_id_to_raw_text) == 2
        assert calls_a[0].id in translator.call_id_to_raw_text
        assert calls_b[0].id in translator.call_id_to_raw_text
# ---------------------------------------------------------------------------
# Inbound: OpenAI messages → raw text
# ---------------------------------------------------------------------------
class TestReconstructRawText:
    """Inbound direction: OpenAI tool_call dicts → raw template text."""

    def test_reconstruct_from_lookup(self, translator):
        # Parsing first seeds the call-id → raw-text lookup table.
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]
        _, calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        dumped = [call.model_dump() for call in calls]
        assert translator.reconstruct_raw_text_from_tool_calls(dumped) == raw

    def test_reconstruct_fallback_without_lookup(self, translator):
        # Unknown id → translator must synthesize <tool_call> text itself.
        dumped = [
            {
                "id": "fake-id-123",
                "type": "function",
                "function": {"name": "search", "arguments": '{"query": "cats"}'},
            }
        ]
        rebuilt = translator.reconstruct_raw_text_from_tool_calls(dumped)
        for fragment in ("<tool_call>", "search", "cats"):
            assert fragment in rebuilt

    def test_reconstruct_empty_list(self, translator):
        assert translator.reconstruct_raw_text_from_tool_calls([]) == ""

    def test_reconstruct_multiple_tool_calls(self, translator):
        dumped = [
            {
                "id": "id-1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "SF"}'},
            },
            {
                "id": "id-2",
                "type": "function",
                "function": {"name": "get_time", "arguments": '{"tz": "PST"}'},
            },
        ]
        rebuilt = translator.reconstruct_raw_text_from_tool_calls(dumped)
        assert rebuilt.count("<tool_call>") == 2
        assert "get_weather" in rebuilt
        assert "get_time" in rebuilt
# ---------------------------------------------------------------------------
# Message conversion
# ---------------------------------------------------------------------------
class TestConvertMessages:
    """convert_messages_for_template: OpenAI message list → template-ready list.

    Assistant messages carrying tool_calls are rewritten to raw
    ``<tool_call>`` text (via the lookup table or the fallback), while every
    other message passes through unchanged.
    """

    def test_regular_messages_pass_through(self, translator):
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi there"},
        ]
        result = translator.convert_messages_for_template(messages)
        assert result == messages

    def test_assistant_with_tool_calls_reconstructed(self, translator):
        # Parse first to populate lookup
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]
        _, tool_calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        messages = [
            {"role": "user", "content": "Search for cats"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [tc.model_dump() for tc in tool_calls],
            },
            {
                "role": "tool",
                "tool_call_id": tool_calls[0].id,
                "content": "Found 5 cats",
            },
        ]
        result = translator.convert_messages_for_template(messages)
        # User message unchanged
        assert result[0] == messages[0]
        # Assistant message reconstructed to raw text
        assert result[1]["role"] == "assistant"
        assert "<tool_call>" in result[1]["content"]
        assert "tool_calls" not in result[1]
        # Tool message passed through
        assert result[2] == messages[2]

    def test_assistant_with_content_and_tool_calls(self, translator):
        # Both content and tool_calls present: both must survive in the raw text.
        messages = [
            {
                "role": "assistant",
                "content": "Let me search.",
                "tool_calls": [
                    {
                        "id": "fake-id",
                        "type": "function",
                        "function": {"name": "search", "arguments": '{"q": "x"}'},
                    }
                ],
            },
        ]
        result = translator.convert_messages_for_template(messages)
        assert result[0]["role"] == "assistant"
        assert "Let me search." in result[0]["content"]
        assert "<tool_call>" in result[0]["content"]

    def test_mixed_message_types(self, translator):
        """Only tool_call assistant messages are reconstructed."""
        messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},  # regular, no tool_calls
            {"role": "user", "content": "Search cats"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "tc-1",
                        "type": "function",
                        "function": {"name": "search", "arguments": '{"q": "cats"}'},
                    }
                ],
            },
            {"role": "tool", "tool_call_id": "tc-1", "content": "5 results"},
            {"role": "assistant", "content": "Found 5 cats!"},  # regular again
        ]
        result = translator.convert_messages_for_template(messages)
        # Messages at indices 0, 1, 2, 4, 5 should be unchanged
        assert result[0] == messages[0]
        assert result[1] == messages[1]
        assert result[2] == messages[2]
        assert result[4] == messages[4]
        assert result[5] == messages[5]
        # Message at index 3 should be reconstructed
        assert "<tool_call>" in result[3]["content"]
# ---------------------------------------------------------------------------
# Roundtrip
# ---------------------------------------------------------------------------
class TestRoundtrip:
    """parse → model_dump → reconstruct must return the original raw text."""

    def test_single_tool_call_roundtrip(self, translator):
        raw = (
            '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
        )
        tools = [{"type": "function", "function": {"name": "search"}}]
        _, calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        rebuilt = translator.reconstruct_raw_text_from_tool_calls(
            [call.model_dump() for call in calls]
        )
        assert rebuilt == raw

    def test_tool_call_empty_arguments(self, translator):
        # Empty arguments object must survive parsing as "{}".
        raw = '<tool_call>{"name": "list_all", "arguments": {}}</tool_call>'
        tools = [{"type": "function", "function": {"name": "list_all"}}]
        _, calls, _ = translator.parse_model_output(
            raw, tool_choice="auto", tools=tools
        )
        assert calls is not None
        assert json.loads(calls[0].function.arguments) == {}
# ---------------------------------------------------------------------------
# Decode with tool awareness
# ---------------------------------------------------------------------------
class TestDecodeToolAwareness:
    """decode_with_tool_awareness toggles special-token stripping on has_tools."""

    def test_decode_without_tools(self, translator):
        hello = [72, 101, 108, 108, 111]  # "Hello"
        assert translator.decode_with_tool_awareness(hello, has_tools=False) == "Hello"

    def test_decode_with_tools_preserves_special(self, translator):
        # The mock tokenizer has no printable specials to strip; this only
        # verifies the flag is passed through correctly.
        hello = [72, 101, 108, 108, 111]
        assert translator.decode_with_tool_awareness(hello, has_tools=True) == "Hello"

    def test_decode_strips_bos_without_tools(self, translator):
        bos_hello = [1, 72, 101, 108, 108, 111]  # BOS + "Hello"
        # BOS is stripped when tools are absent.
        assert translator.decode_with_tool_awareness(bos_hello, has_tools=False) == "Hello"

    def test_decode_keeps_bos_with_tools(self, translator):
        bos_hello = [1, 72, 101, 108, 108, 111]  # BOS + "Hello"
        # skip_special_tokens=False is passed, but chr(1) is non-printable so
        # the mock tokenizer renders it as "".
        decoded = translator.decode_with_tool_awareness(bos_hello, has_tools=True)
        assert "Hello" in decoded