Enhance ManagedServer and ServerManager to support tool call parsing. Add an optional tool_call_parser parameter to ManagedServer and the managed_server context manager, enabling structured extraction of tool calls from raw completion text. Update _convert_messages_to_prompt to forward additional template kwargs for improved prompt construction.

This commit is contained in:
teknium 2026-02-07 09:12:21 +00:00
parent 7da681ec46
commit 49199fa6b7
2 changed files with 94 additions and 15 deletions

View file

@ -58,6 +58,7 @@ class ManagedServer:
server: APIServer,
tokenizer: Optional[Any] = None,
track_tree: bool = False,
tool_call_parser=None,
):
"""
Initialize the managed server.
@ -69,10 +70,14 @@ class ManagedServer:
track_tree: If True, maintains a tree structure with parent-child links
(for multi-turn RL with per-step advantages). If False (default),
maintains a simple list of current nodes that updates in-place.
tool_call_parser: Optional tool call parser instance that extracts structured
tool_calls from raw completion text. Used in Phase 2 when
/generate returns raw tokens without tool call parsing.
"""
self.server = server
self.tokenizer = tokenizer
self.track_tree = track_tree
self.tool_call_parser = tool_call_parser
# Initialize storage based on mode
if track_tree:
@ -103,12 +108,18 @@ class ManagedServer:
)
self.tokenizer = None
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
# Kwargs that should be forwarded to apply_chat_template (not to the completion endpoint)
# Models like Kimi-K2, Qwen3 use these for tool formatting, reasoning block handling, etc.
CHAT_TEMPLATE_KWARGS = {"tools", "thinking", "enable_thinking", "documents"}
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]], **template_kwargs) -> str:
"""
Convert chat messages to prompt text using tokenizer's chat template.
Args:
messages: List of message dicts with 'role' and 'content'
**template_kwargs: Additional kwargs to forward to apply_chat_template
(e.g., tools=, thinking=, enable_thinking=)
Returns:
Formatted prompt string
@ -123,10 +134,18 @@ class ManagedServer:
len(messages) == 0 or messages[-1].get("role") != "assistant"
)
# Build kwargs for apply_chat_template, forwarding tools/thinking/etc.
ct_kwargs = {
"tokenize": False,
"add_generation_prompt": add_generation_prompt,
}
# Forward any template-relevant kwargs (tools, thinking, etc.)
for key, value in template_kwargs.items():
if value is not None:
ct_kwargs[key] = value
# Use the tokenizer's chat template
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return self.tokenizer.apply_chat_template(messages, **ct_kwargs)
else:
# Fallback for tokenizers without chat template
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
@ -260,15 +279,28 @@ class ManagedServer:
Internally converts to prompt, calls tokens_and_logprobs_completion,
tracks the sequence, and reconstructs a ChatCompletion response.
Supports forwarding tools=, thinking=, enable_thinking=, etc. to
apply_chat_template() for proper prompt construction with tool schemas
and reasoning configuration.
If a tool_call_parser is set on this instance, it will be used to
extract structured tool_calls from the raw completion text.
Args:
**kwargs: Standard chat completion kwargs (messages, n, etc.)
**kwargs: Standard chat completion kwargs (messages, tools, n, etc.)
Returns:
ChatCompletion response
"""
# Get input text
# Extract kwargs meant for the chat template (not for the completion endpoint)
template_kwargs = {}
for key in list(kwargs.keys()):
if key in self.CHAT_TEMPLATE_KWARGS:
template_kwargs[key] = kwargs.pop(key)
# Get input text, forwarding template kwargs (tools, thinking, etc.)
messages = kwargs.get("messages", [])
prompt = self._convert_messages_to_prompt(messages)
prompt = self._convert_messages_to_prompt(messages, **template_kwargs)
# Handle parent node and extending logic based on mode
if self.track_tree:
@ -318,12 +350,20 @@ class ManagedServer:
finish_reason = finish_reason_raw
# Decode completion text
# Use skip_special_tokens=False if we have a tool call parser,
# so we can see the tool call special tokens for parsing
if self.tokenizer is not None:
skip_special = not hasattr(self, "tool_call_parser") or self.tool_call_parser is None
completion_text = self.tokenizer.decode(
output_tokens, skip_special_tokens=True
output_tokens, skip_special_tokens=skip_special
)
# Also decode a clean version for the content field
clean_text = self.tokenizer.decode(
output_tokens, skip_special_tokens=True
) if not skip_special else completion_text
else:
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
clean_text = completion_text
# Create and store sequence node
node = self._create_sequence_node(
@ -354,14 +394,47 @@ class ManagedServer:
# New context - append to list
self.current_nodes.append(node)
# Build choice
# Apply tool call parser if configured (Phase 2)
parsed_content = clean_text
parsed_tool_calls = None
parsed_finish_reason = finish_reason
if hasattr(self, "tool_call_parser") and self.tool_call_parser is not None and template_kwargs.get("tools"):
try:
parsed_content, parsed_tool_calls = self.tool_call_parser.parse(completion_text)
if parsed_tool_calls:
parsed_finish_reason = "tool_calls"
if parsed_content is None:
parsed_content = ""
except Exception:
pass # Fall through to no tool calls
# Extract reasoning content from <think> blocks or similar
reasoning_content = None
final_content = parsed_content
import re as _re
think_match = _re.search(r"<think>(.*?)</think>", final_content or "", _re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
# Strip <think> blocks from content
final_content = _re.sub(r"<think>.*?</think>", "", final_content, flags=_re.DOTALL).strip()
# Build choice with tool_calls and reasoning if found
message_kwargs = {"content": final_content or "", "role": "assistant"}
choice = Choice(
finish_reason=finish_reason,
finish_reason=parsed_finish_reason,
index=i,
message=ChatCompletionMessage(
content=completion_text, role="assistant"
),
message=ChatCompletionMessage(**message_kwargs),
)
# Attach tool_calls if parsed
if parsed_tool_calls:
choice.message.tool_calls = parsed_tool_calls
# Attach reasoning_content as an extra attribute
if reasoning_content:
choice.message.reasoning_content = reasoning_content
choices.append(choice)
# Construct ChatCompletion response