mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
|
|
@ -493,6 +493,295 @@ async def test_multi_turn_chat_with_branching(mock_server):
|
|||
assert f"More{actual_i}" in node.full_text # Has third turn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Tool call support in ManagedServer.chat_completion()
# ---------------------------------------------------------------------------


class MockTokenizerWithTools(MockTokenizer):
    """Extended mock tokenizer that supports tools kwarg in apply_chat_template."""

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True, tools=None
    ):
        # Render a minimal pseudo chat template: an optional serialized tool
        # schema block, then one <role>...</role> span per message.
        parts = []
        if tools:
            import json

            parts.append(f"<tools>{json.dumps(tools)}</tools>\n")
        for message in messages:
            # A missing or None content renders as an empty string.
            body = message.get("content", "") or ""
            role = message["role"]
            parts.append(f"<{role}>{body}</{role}>")
        if add_generation_prompt:
            parts.append("<assistant>")
        rendered = "".join(parts)
        # tokenize=True mirrors the HF API: return token ids, not text.
        return self.encode(rendered) if tokenize else rendered
|
||||
|
||||
|
||||
@pytest.fixture
def mock_server_with_tools():
    """Mock server with tool-aware tokenizer."""

    class Config:
        # Minimal stand-in for the real server config object.
        model_name = "test_model"

    harness = ServerHarness()
    harness.tokenizer = MockTokenizerWithTools()
    harness.config = Config()
    return harness
|
||||
|
||||
|
||||
def _setup_chat_completion(server, tokenizer, messages, output_texts, tools=None):
    """Helper: set up mock tokens_and_logprobs for a chat_completion call."""
    # Build the exact prompt the server will look up, including the tool block.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, tools=tools
    )
    encoded_prompt = tokenizer.encode(prompt)

    # One completion per requested text; tokens are just codepoints, and every
    # completion token gets a uniform dummy logprob with a "stop" finish.
    completion_tokens = [list(map(ord, text)) for text in output_texts]
    completion_logprobs = [[-0.1] * len(toks) for toks in completion_tokens]

    server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=encoded_prompt,
        output_tokens_list=completion_tokens,
        output_logprobs_list=completion_logprobs,
        finish_reasons=["stop"] * len(output_texts),
    )
    return prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_parsing_outbound(mock_server_with_tools):
    """Model generates <tool_call> → chat_completion returns structured tool_calls."""
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Search cats"}]
    raw_output = (
        '<tool_call>{"name": "search", "arguments": {"query": "cats"}}</tool_call>'
    )

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await server.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )

    # Exactly one choice whose raw <tool_call> text was parsed into a
    # structured tool_calls list with a tool_calls finish reason.
    assert len(result.choices) == 1
    only_choice = result.choices[0]
    assert only_choice.finish_reason == "tool_calls"
    assert only_choice.message.tool_calls is not None
    assert len(only_choice.message.tool_calls) == 1
    parsed_call = only_choice.message.tool_calls[0]
    assert parsed_call["function"]["name"] == "search"

    # Node should have raw text (not parsed)
    assert len(server.get_state()["nodes"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_choice_none_skips(mock_server_with_tools):
    """tool_choice='none' returns raw text, no parsing."""
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await server.chat_completion(
        messages=messages, tools=tools, tool_choice="none"
    )

    # Parsing is skipped entirely: no structured tool_calls, plain stop.
    choice = result.choices[0]
    assert choice.message.tool_calls is None
    assert choice.finish_reason == "stop"
    # Raw text should be content
    assert "<tool_call>" in choice.message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_tool_parser_passes_through(mock_server_with_tools):
    """Without tool_parser, tools kwarg is ignored — no parsing."""
    # No tool_parser configured on this server.
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
    )

    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools, mock_server_with_tools.tokenizer, messages, [raw_output]
    )

    result = await server.chat_completion(messages=messages)

    # No tool parsing — raw text as content
    choice = result.choices[0]
    assert choice.message.tool_calls is None
    assert choice.finish_reason == "stop"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_multi_turn_extends_node(mock_server_with_tools):
    """Multi-turn with tool calls should extend to 1 node."""
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )
    tokenizer = mock_server_with_tools.tokenizer
    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]

    # Step 1: user → tool_call
    first_messages = [{"role": "user", "content": "Search cats"}]
    first_output = '<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
    _setup_chat_completion(
        mock_server_with_tools, tokenizer, first_messages, [first_output], tools=tools
    )

    first_result = await server.chat_completion(
        messages=first_messages, tools=tools, tool_choice="auto"
    )
    parsed_calls = first_result.choices[0].message.tool_calls

    assert len(server.get_state()["nodes"]) == 1

    # Step 2: include tool result → plain response
    # Reconstruct the assistant message with tool_calls for the translator
    second_messages = [
        {"role": "user", "content": "Search cats"},
        {"role": "assistant", "content": None, "tool_calls": parsed_calls},
        {
            "role": "tool",
            "tool_call_id": parsed_calls[0]["id"],
            "content": "Found 5 cats",
        },
    ]

    # The translator will reconstruct the tool_call to raw text,
    # so we need the prompt to match what it produces
    second_output = "Here are 5 cats!"
    second_prompt = tokenizer.apply_chat_template(
        server._get_translator().convert_messages_for_template(second_messages),
        tokenize=False,
        add_generation_prompt=True,
        tools=tools,
    )
    second_prompt_tokens = tokenizer.encode(second_prompt)
    second_output_tokens = [ord(c) for c in second_output]
    mock_server_with_tools.set_tokens_and_logprobs_response(
        prompt=second_prompt,
        prompt_tokens=second_prompt_tokens,
        output_tokens_list=[second_output_tokens],
        output_logprobs_list=[[-0.1] * len(second_output_tokens)],
        finish_reasons=["stop"],
    )

    second_result = await server.chat_completion(
        messages=second_messages, tools=tools, tool_choice="auto"
    )
    assert second_result.choices[0].message.content == second_output

    # Still 1 node — step 2 extended step 1
    assert len(server.get_state()["nodes"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_multiple_tools_parsed(mock_server_with_tools):
    """Multiple tool calls in one response are all parsed."""
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        {"type": "function", "function": {"name": "get_time", "parameters": {}}},
    ]
    messages = [{"role": "user", "content": "Weather and time?"}]
    # Two back-to-back <tool_call> blocks in a single completion.
    raw_output = (
        '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>\n'
        '<tool_call>{"name": "get_time", "arguments": {"tz": "PST"}}</tool_call>'
    )
    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    result = await server.chat_completion(
        messages=messages, tools=tools, tool_choice="auto"
    )

    choice = result.choices[0]
    assert choice.finish_reason == "tool_calls"
    assert len(choice.message.tool_calls) == 2
    parsed_names = {call["function"]["name"] for call in choice.message.tool_calls}
    assert parsed_names == {"get_weather", "get_time"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_call_node_masking(mock_server_with_tools):
    """Nodes have proper masking even with tool parsing active."""
    server = ManagedServer(
        mock_server_with_tools,
        tokenizer=mock_server_with_tools.tokenizer,
        tool_parser="hermes",
    )

    tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]
    messages = [{"role": "user", "content": "Hi"}]
    raw_output = '<tool_call>{"name": "search", "arguments": {"q": "x"}}</tool_call>'

    _setup_chat_completion(
        mock_server_with_tools,
        mock_server_with_tools.tokenizer,
        messages,
        [raw_output],
        tools=tools,
    )

    await server.chat_completion(messages=messages, tools=tools)

    node = server.get_state()["nodes"][0]

    # Lengths must match
    assert len(node.tokens) == len(node.masked_tokens) == len(node.logprobs)

    # Should have masked prompt tokens and actual completion tokens
    masked_count = node.masked_tokens.count(-100)
    unmasked_count = len(node.masked_tokens) - masked_count
    assert masked_count > 0
    assert unmasked_count > 0

    # Prompt logprobs = 1.0, completion logprobs < 0
    assert all(lp == 1.0 for lp in node.logprobs[:masked_count])
    assert all(lp < 0 for lp in node.logprobs[masked_count:])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, without invoking the pytest CLI.
    pytest.main([__file__, "-v"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue