mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
Enhance ManagedServer and ServerManager to support tool call parsing. Add an optional tool_call_parser parameter to ManagedServer and to the managed_server context manager, enabling structured extraction of tool calls from raw completion text. Update _convert_messages_to_prompt to forward additional chat-template kwargs (e.g. tools, thinking, enable_thinking) for improved prompt construction.
This commit is contained in:
parent
7da681ec46
commit
49199fa6b7
2 changed files with 94 additions and 15 deletions
|
|
@ -58,6 +58,7 @@ class ManagedServer:
|
|||
server: APIServer,
|
||||
tokenizer: Optional[Any] = None,
|
||||
track_tree: bool = False,
|
||||
tool_call_parser=None,
|
||||
):
|
||||
"""
|
||||
Initialize the managed server.
|
||||
|
|
@ -69,10 +70,14 @@ class ManagedServer:
|
|||
track_tree: If True, maintains a tree structure with parent-child links
|
||||
(for multi-turn RL with per-step advantages). If False (default),
|
||||
maintains a simple list of current nodes that updates in-place.
|
||||
tool_call_parser: Optional tool call parser instance that extracts structured
|
||||
tool_calls from raw completion text. Used in Phase 2 when
|
||||
/generate returns raw tokens without tool call parsing.
|
||||
"""
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self.tool_call_parser = tool_call_parser
|
||||
|
||||
# Initialize storage based on mode
|
||||
if track_tree:
|
||||
|
|
@ -103,12 +108,18 @@ class ManagedServer:
|
|||
)
|
||||
self.tokenizer = None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
||||
# Kwargs that should be forwarded to apply_chat_template (not to the completion endpoint)
|
||||
# Models like Kimi-K2, Qwen3 use these for tool formatting, reasoning block handling, etc.
|
||||
CHAT_TEMPLATE_KWARGS = {"tools", "thinking", "enable_thinking", "documents"}
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]], **template_kwargs) -> str:
|
||||
"""
|
||||
Convert chat messages to prompt text using tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
**template_kwargs: Additional kwargs to forward to apply_chat_template
|
||||
(e.g., tools=, thinking=, enable_thinking=)
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
|
|
@ -123,10 +134,18 @@ class ManagedServer:
|
|||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Build kwargs for apply_chat_template, forwarding tools/thinking/etc.
|
||||
ct_kwargs = {
|
||||
"tokenize": False,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
# Forward any template-relevant kwargs (tools, thinking, etc.)
|
||||
for key, value in template_kwargs.items():
|
||||
if value is not None:
|
||||
ct_kwargs[key] = value
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.tokenizer.apply_chat_template(messages, **ct_kwargs)
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
|
@ -260,15 +279,28 @@ class ManagedServer:
|
|||
Internally converts to prompt, calls tokens_and_logprobs_completion,
|
||||
tracks the sequence, and reconstructs a ChatCompletion response.
|
||||
|
||||
Supports forwarding tools=, thinking=, enable_thinking=, etc. to
|
||||
apply_chat_template() for proper prompt construction with tool schemas
|
||||
and reasoning configuration.
|
||||
|
||||
If a tool_call_parser is set on this instance, it will be used to
|
||||
extract structured tool_calls from the raw completion text.
|
||||
|
||||
Args:
|
||||
**kwargs: Standard chat completion kwargs (messages, n, etc.)
|
||||
**kwargs: Standard chat completion kwargs (messages, tools, n, etc.)
|
||||
|
||||
Returns:
|
||||
ChatCompletion response
|
||||
"""
|
||||
# Get input text
|
||||
# Extract kwargs meant for the chat template (not for the completion endpoint)
|
||||
template_kwargs = {}
|
||||
for key in list(kwargs.keys()):
|
||||
if key in self.CHAT_TEMPLATE_KWARGS:
|
||||
template_kwargs[key] = kwargs.pop(key)
|
||||
|
||||
# Get input text, forwarding template kwargs (tools, thinking, etc.)
|
||||
messages = kwargs.get("messages", [])
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
prompt = self._convert_messages_to_prompt(messages, **template_kwargs)
|
||||
|
||||
# Handle parent node and extending logic based on mode
|
||||
if self.track_tree:
|
||||
|
|
@ -318,12 +350,20 @@ class ManagedServer:
|
|||
finish_reason = finish_reason_raw
|
||||
|
||||
# Decode completion text
|
||||
# Use skip_special_tokens=False if we have a tool call parser,
|
||||
# so we can see the tool call special tokens for parsing
|
||||
if self.tokenizer is not None:
|
||||
skip_special = not hasattr(self, "tool_call_parser") or self.tool_call_parser is None
|
||||
completion_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
output_tokens, skip_special_tokens=skip_special
|
||||
)
|
||||
# Also decode a clean version for the content field
|
||||
clean_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
) if not skip_special else completion_text
|
||||
else:
|
||||
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
|
||||
clean_text = completion_text
|
||||
|
||||
# Create and store sequence node
|
||||
node = self._create_sequence_node(
|
||||
|
|
@ -354,14 +394,47 @@ class ManagedServer:
|
|||
# New context - append to list
|
||||
self.current_nodes.append(node)
|
||||
|
||||
# Build choice
|
||||
# Apply tool call parser if configured (Phase 2)
|
||||
parsed_content = clean_text
|
||||
parsed_tool_calls = None
|
||||
parsed_finish_reason = finish_reason
|
||||
|
||||
if hasattr(self, "tool_call_parser") and self.tool_call_parser is not None and template_kwargs.get("tools"):
|
||||
try:
|
||||
parsed_content, parsed_tool_calls = self.tool_call_parser.parse(completion_text)
|
||||
if parsed_tool_calls:
|
||||
parsed_finish_reason = "tool_calls"
|
||||
if parsed_content is None:
|
||||
parsed_content = ""
|
||||
except Exception:
|
||||
pass # Fall through to no tool calls
|
||||
|
||||
# Extract reasoning content from <think> blocks or similar
|
||||
reasoning_content = None
|
||||
final_content = parsed_content
|
||||
import re as _re
|
||||
think_match = _re.search(r"<think>(.*?)</think>", final_content or "", _re.DOTALL)
|
||||
if think_match:
|
||||
reasoning_content = think_match.group(1).strip()
|
||||
# Strip <think> blocks from content
|
||||
final_content = _re.sub(r"<think>.*?</think>", "", final_content, flags=_re.DOTALL).strip()
|
||||
|
||||
# Build choice with tool_calls and reasoning if found
|
||||
message_kwargs = {"content": final_content or "", "role": "assistant"}
|
||||
choice = Choice(
|
||||
finish_reason=finish_reason,
|
||||
finish_reason=parsed_finish_reason,
|
||||
index=i,
|
||||
message=ChatCompletionMessage(
|
||||
content=completion_text, role="assistant"
|
||||
),
|
||||
message=ChatCompletionMessage(**message_kwargs),
|
||||
)
|
||||
|
||||
# Attach tool_calls if parsed
|
||||
if parsed_tool_calls:
|
||||
choice.message.tool_calls = parsed_tool_calls
|
||||
|
||||
# Attach reasoning_content as an extra attribute
|
||||
if reasoning_content:
|
||||
choice.message.reasoning_content = reasoning_content
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
# Construct ChatCompletion response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue