mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
|
|
@ -62,6 +62,7 @@ class ManagedServer:
|
|||
server: APIServer,
|
||||
tokenizer: Optional[Any] = None,
|
||||
track_tree: bool = False,
|
||||
tool_parser: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the managed server.
|
||||
|
|
@ -73,10 +74,17 @@ class ManagedServer:
|
|||
track_tree: If True, maintains a tree structure with parent-child links
|
||||
(for multi-turn RL with per-step advantages). If False (default),
|
||||
maintains a simple list of current nodes that updates in-place.
|
||||
tool_parser: Optional vLLM tool parser name (e.g. "hermes", "llama3_json",
|
||||
"mistral", etc.). If provided, enables tool call support in
|
||||
chat_completion(). The parser handles extraction of structured
|
||||
tool calls from raw model output. See
|
||||
ToolParserManager.list_registered() for available parsers.
|
||||
"""
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self._tool_parser_name = tool_parser
|
||||
self._translator = None # Lazy init
|
||||
|
||||
# Initialize storage based on mode
|
||||
if track_tree:
|
||||
|
|
@ -107,19 +115,57 @@ class ManagedServer:
|
|||
)
|
||||
self.tokenizer = None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
||||
def _get_translator(self):
|
||||
"""Lazily create the ToolCallTranslator when first needed.
|
||||
|
||||
Returns None if tool_parser was not specified or if vLLM is not
|
||||
installed (the translator will warn on creation in that case).
|
||||
"""
|
||||
if self._translator is None and self._tool_parser_name and self.tokenizer:
|
||||
try:
|
||||
from atroposlib.envs.server_handling.tool_call_translator import (
|
||||
ToolCallTranslator,
|
||||
)
|
||||
|
||||
self._translator = ToolCallTranslator(
|
||||
tokenizer=self.tokenizer,
|
||||
parser_name=self._tool_parser_name,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"Failed to create ToolCallTranslator: {e}. "
|
||||
"Tool call parsing will be disabled.",
|
||||
stacklevel=2,
|
||||
)
|
||||
self._tool_parser_name = None # Don't retry
|
||||
return None
|
||||
return self._translator
|
||||
|
||||
def _convert_messages_to_prompt(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Convert chat messages to prompt text using tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
tools: Optional list of tool definitions (OpenAI format). Passed to
|
||||
apply_chat_template() so the template can inject tool defs
|
||||
into the system prompt.
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# If tools are active and we have a translator, convert any assistant
|
||||
# messages with tool_calls back to raw text first
|
||||
if tools and self._get_translator():
|
||||
messages = self._get_translator().convert_messages_for_template(messages)
|
||||
|
||||
if self.tokenizer is None:
|
||||
# Fallback: simple concatenation
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
|
||||
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
# Only add generation prompt if last message is not from assistant
|
||||
|
|
@ -127,13 +173,19 @@ class ManagedServer:
|
|||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Build kwargs
|
||||
template_kwargs = {
|
||||
"tokenize": False,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
if tools:
|
||||
template_kwargs["tools"] = tools
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
|
||||
|
||||
def _debug_requests_enabled(self) -> bool:
|
||||
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
|
||||
|
|
@ -268,15 +320,35 @@ class ManagedServer:
|
|||
Internally converts to prompt, calls tokens_and_logprobs_completion,
|
||||
tracks the sequence, and reconstructs a ChatCompletion response.
|
||||
|
||||
Supports tool calling when a tool_parser was provided at init:
|
||||
- Accepts `tools` and `tool_choice` kwargs
|
||||
- Converts inbound assistant tool_call messages to raw text
|
||||
- Parses outbound model output for tool calls
|
||||
- Returns ChatCompletion with proper tool_calls in choices
|
||||
- Preserves raw text in tracked nodes (tool parsing is response-only)
|
||||
|
||||
Args:
|
||||
**kwargs: Standard chat completion kwargs (messages, n, etc.)
|
||||
**kwargs: Standard chat completion kwargs (messages, n, max_tokens,
|
||||
temperature, tools, tool_choice, etc.)
|
||||
|
||||
Returns:
|
||||
ChatCompletion response
|
||||
ChatCompletion response (with tool_calls if detected)
|
||||
"""
|
||||
# Get input text
|
||||
# Extract tool-related kwargs
|
||||
tools = kwargs.pop("tools", None)
|
||||
tool_choice = kwargs.pop("tool_choice", None)
|
||||
has_tools = bool(tools) and self._get_translator() is not None
|
||||
|
||||
# Default tool_choice to "auto" if tools provided
|
||||
if has_tools and tool_choice is None:
|
||||
tool_choice = "auto"
|
||||
|
||||
# Get input text — passes tools for template rendering and
|
||||
# handles reconstruction of inbound tool_call messages
|
||||
messages = kwargs.get("messages", [])
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
prompt = self._convert_messages_to_prompt(
|
||||
messages, tools=tools if has_tools else None
|
||||
)
|
||||
|
||||
# Handle parent node and extending logic based on mode
|
||||
if self.track_tree:
|
||||
|
|
@ -296,11 +368,12 @@ class ManagedServer:
|
|||
msg_count = len(messages)
|
||||
prompt_preview = prompt.replace("\n", "\\n")[:600]
|
||||
logger.debug(
|
||||
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
|
||||
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s tools=%s",
|
||||
msg_count,
|
||||
completion_kwargs.get("n"),
|
||||
completion_kwargs.get("max_tokens"),
|
||||
completion_kwargs.get("temperature"),
|
||||
bool(tools),
|
||||
)
|
||||
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
|
||||
|
||||
|
|
@ -336,15 +409,18 @@ class ManagedServer:
|
|||
else:
|
||||
finish_reason = finish_reason_raw
|
||||
|
||||
# Decode completion text
|
||||
# Decode completion text — use skip_special_tokens=False when
|
||||
# tools are active so <tool_call> tags aren't stripped
|
||||
if self.tokenizer is not None:
|
||||
completion_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
output_tokens,
|
||||
skip_special_tokens=not has_tools,
|
||||
)
|
||||
else:
|
||||
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
|
||||
|
||||
# Create and store sequence node
|
||||
# Create and store sequence node — always uses the raw text,
|
||||
# tool parsing only affects the ChatCompletion response
|
||||
node = self._create_sequence_node(
|
||||
input_text=prompt,
|
||||
parent_node=parent_node,
|
||||
|
|
@ -373,14 +449,50 @@ class ManagedServer:
|
|||
# New context - append to list
|
||||
self.current_nodes.append(node)
|
||||
|
||||
# Parse tool calls from raw output if tools are active
|
||||
tool_calls_parsed = None
|
||||
content_for_response = completion_text
|
||||
if has_tools and tool_choice != "none":
|
||||
translator = self._get_translator()
|
||||
content_for_response, tool_calls_parsed, finish_reason = (
|
||||
translator.parse_model_output(
|
||||
raw_text=completion_text,
|
||||
tool_choice=(
|
||||
tool_choice if isinstance(tool_choice, str) else "auto"
|
||||
),
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
# Build choice
|
||||
message_kwargs = {
|
||||
"content": content_for_response,
|
||||
"role": "assistant",
|
||||
}
|
||||
# Note: openai's ChatCompletionMessage model handles tool_calls
|
||||
# but we can't pass them through the constructor easily. We'll
|
||||
# attach them after construction if needed.
|
||||
choice = Choice(
|
||||
finish_reason=finish_reason,
|
||||
index=i,
|
||||
message=ChatCompletionMessage(
|
||||
content=completion_text, role="assistant"
|
||||
),
|
||||
message=ChatCompletionMessage(**message_kwargs),
|
||||
)
|
||||
|
||||
# Attach tool_calls to the message if present
|
||||
if tool_calls_parsed:
|
||||
choice.message.tool_calls = [
|
||||
# Convert vLLM ToolCall to openai ToolCall format
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_parsed
|
||||
]
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
# Construct ChatCompletion response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue