Add tool-call parsing (based on the vLLM implementation) and an OpenAI-compatible server endpoint

This commit is contained in:
dmahan93 2026-03-02 23:17:13 -06:00
parent 887a94374c
commit add42a2afb
11 changed files with 3370 additions and 34 deletions

View file

@ -62,6 +62,7 @@ class ManagedServer:
server: APIServer,
tokenizer: Optional[Any] = None,
track_tree: bool = False,
tool_parser: Optional[str] = None,
):
"""
Initialize the managed server.
@ -73,10 +74,17 @@ class ManagedServer:
track_tree: If True, maintains a tree structure with parent-child links
(for multi-turn RL with per-step advantages). If False (default),
maintains a simple list of current nodes that updates in-place.
tool_parser: Optional vLLM tool parser name (e.g. "hermes", "llama3_json",
"mistral", etc.). If provided, enables tool call support in
chat_completion(). The parser handles extraction of structured
tool calls from raw model output. See
ToolParserManager.list_registered() for available parsers.
"""
self.server = server
self.tokenizer = tokenizer
self.track_tree = track_tree
self._tool_parser_name = tool_parser
self._translator = None # Lazy init
# Initialize storage based on mode
if track_tree:
@ -107,19 +115,57 @@ class ManagedServer:
)
self.tokenizer = None
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
def _get_translator(self):
"""Lazily create the ToolCallTranslator when first needed.
Returns None if tool_parser was not specified or if vLLM is not
installed (the translator will warn on creation in that case).
"""
if self._translator is None and self._tool_parser_name and self.tokenizer:
try:
from atroposlib.envs.server_handling.tool_call_translator import (
ToolCallTranslator,
)
self._translator = ToolCallTranslator(
tokenizer=self.tokenizer,
parser_name=self._tool_parser_name,
)
except Exception as e:
warnings.warn(
f"Failed to create ToolCallTranslator: {e}. "
"Tool call parsing will be disabled.",
stacklevel=2,
)
self._tool_parser_name = None # Don't retry
return None
return self._translator
def _convert_messages_to_prompt(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None,
) -> str:
"""
Convert chat messages to prompt text using tokenizer's chat template.
Args:
messages: List of message dicts with 'role' and 'content'
tools: Optional list of tool definitions (OpenAI format). Passed to
apply_chat_template() so the template can inject tool defs
into the system prompt.
Returns:
Formatted prompt string
"""
# If tools are active and we have a translator, convert any assistant
# messages with tool_calls back to raw text first
if tools and self._get_translator():
messages = self._get_translator().convert_messages_for_template(messages)
if self.tokenizer is None:
# Fallback: simple concatenation
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
if hasattr(self.tokenizer, "apply_chat_template"):
# Only add generation prompt if last message is not from assistant
@ -127,13 +173,19 @@ class ManagedServer:
len(messages) == 0 or messages[-1].get("role") != "assistant"
)
# Build kwargs
template_kwargs = {
"tokenize": False,
"add_generation_prompt": add_generation_prompt,
}
if tools:
template_kwargs["tools"] = tools
# Use the tokenizer's chat template
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
else:
# Fallback for tokenizers without chat template
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
def _debug_requests_enabled(self) -> bool:
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
@ -268,15 +320,35 @@ class ManagedServer:
Internally converts to prompt, calls tokens_and_logprobs_completion,
tracks the sequence, and reconstructs a ChatCompletion response.
Supports tool calling when a tool_parser was provided at init:
- Accepts `tools` and `tool_choice` kwargs
- Converts inbound assistant tool_call messages to raw text
- Parses outbound model output for tool calls
- Returns ChatCompletion with proper tool_calls in choices
- Preserves raw text in tracked nodes (tool parsing is response-only)
Args:
**kwargs: Standard chat completion kwargs (messages, n, etc.)
**kwargs: Standard chat completion kwargs (messages, n, max_tokens,
temperature, tools, tool_choice, etc.)
Returns:
ChatCompletion response
ChatCompletion response (with tool_calls if detected)
"""
# Get input text
# Extract tool-related kwargs
tools = kwargs.pop("tools", None)
tool_choice = kwargs.pop("tool_choice", None)
has_tools = bool(tools) and self._get_translator() is not None
# Default tool_choice to "auto" if tools provided
if has_tools and tool_choice is None:
tool_choice = "auto"
# Get input text — passes tools for template rendering and
# handles reconstruction of inbound tool_call messages
messages = kwargs.get("messages", [])
prompt = self._convert_messages_to_prompt(messages)
prompt = self._convert_messages_to_prompt(
messages, tools=tools if has_tools else None
)
# Handle parent node and extending logic based on mode
if self.track_tree:
@ -296,11 +368,12 @@ class ManagedServer:
msg_count = len(messages)
prompt_preview = prompt.replace("\n", "\\n")[:600]
logger.debug(
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s tools=%s",
msg_count,
completion_kwargs.get("n"),
completion_kwargs.get("max_tokens"),
completion_kwargs.get("temperature"),
bool(tools),
)
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
@ -336,15 +409,18 @@ class ManagedServer:
else:
finish_reason = finish_reason_raw
# Decode completion text
# Decode completion text — use skip_special_tokens=False when
# tools are active so <tool_call> tags aren't stripped
if self.tokenizer is not None:
completion_text = self.tokenizer.decode(
output_tokens, skip_special_tokens=True
output_tokens,
skip_special_tokens=not has_tools,
)
else:
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
# Create and store sequence node
# Create and store sequence node — always uses the raw text,
# tool parsing only affects the ChatCompletion response
node = self._create_sequence_node(
input_text=prompt,
parent_node=parent_node,
@ -373,14 +449,50 @@ class ManagedServer:
# New context - append to list
self.current_nodes.append(node)
# Parse tool calls from raw output if tools are active
tool_calls_parsed = None
content_for_response = completion_text
if has_tools and tool_choice != "none":
translator = self._get_translator()
content_for_response, tool_calls_parsed, finish_reason = (
translator.parse_model_output(
raw_text=completion_text,
tool_choice=(
tool_choice if isinstance(tool_choice, str) else "auto"
),
tools=tools,
)
)
# Build choice
message_kwargs = {
"content": content_for_response,
"role": "assistant",
}
# Note: openai's ChatCompletionMessage model handles tool_calls
# but we can't pass them through the constructor easily. We'll
# attach them after construction if needed.
choice = Choice(
finish_reason=finish_reason,
index=i,
message=ChatCompletionMessage(
content=completion_text, role="assistant"
),
message=ChatCompletionMessage(**message_kwargs),
)
# Attach tool_calls to the message if present
if tool_calls_parsed:
choice.message.tool_calls = [
# Convert vLLM ToolCall to openai ToolCall format
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls_parsed
]
choices.append(choice)
# Construct ChatCompletion response