mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
qwen tokenizer wrapper & fixed jinja template for tool handling (#224)
* added qwen tokenizer wrapper & fixed jinja template for tool handling issues in the official HF one * moved jinja template into it's own file
This commit is contained in:
parent
56fb50a503
commit
9f23c732dd
3 changed files with 117 additions and 0 deletions
72
atroposlib/utils/tokenizers/qwen_fixed_tokenizer.py
Normal file
72
atroposlib/utils/tokenizers/qwen_fixed_tokenizer.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""
|
||||
Custom Qwen tokenizer wrapper with fixed Jinja2 template.
|
||||
|
||||
This wrapper overrides the chat_template to avoid Jinja2 sandbox restrictions
|
||||
that prevent list.append() operations in the original Qwen tokenizer.
|
||||
|
||||
TLDR; tool calls with Qwen are a PITA
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class QwenFixedTokenizer:
    """Wrapper around a Qwen tokenizer that installs a fixed chat template.

    The stock Qwen chat template uses Jinja2 constructs (e.g. ``list.append``)
    that the sandboxed Jinja2 environment rejects, which breaks tool-call
    rendering.  This wrapper loads a corrected template from
    ``qwen_chat_template.jinja`` (shipped alongside this module), overrides
    ``chat_template`` on the underlying tokenizer, and transparently delegates
    every other attribute to it.
    """

    @classmethod
    def _load_chat_template(cls) -> str:
        """Load the fixed chat template from the adjacent .jinja file.

        Returns:
            The template text, read as UTF-8.
        """
        template_path = Path(__file__).parent / "qwen_chat_template.jinja"
        return template_path.read_text(encoding="utf-8")

    def __init__(self, model_name_or_path: str, **kwargs):
        """Initialize the tokenizer wrapper.

        Args:
            model_name_or_path: Model name or path to load tokenizer from
            **kwargs: Additional arguments passed to AutoTokenizer.from_pretrained
        """
        # Load the base tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)

        # Override the chat template with our fixed version
        self.tokenizer.chat_template = self._load_chat_template()

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attributes to the underlying tokenizer.

        Raises:
            AttributeError: If ``tokenizer`` itself is missing (e.g. ``__init__``
                raised before assigning it, or the instance is being unpickled).
        """
        # Guard: if ``tokenizer`` was never set, the ``self.tokenizer`` lookup
        # below would fail and re-enter __getattr__ with name="tokenizer",
        # recursing forever. Raise a normal AttributeError instead.
        if name == "tokenizer":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'tokenizer'"
            )
        return getattr(self.tokenizer, name)

    def apply_chat_template(
        self,
        conversation: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, bool]] = None,
        return_dict: bool = False,
        add_generation_prompt: bool = False,
        **kwargs,
    ):
        """Apply the fixed chat template.

        This method delegates to the underlying tokenizer's apply_chat_template
        but ensures our fixed template is used (it was installed on the
        tokenizer in ``__init__``).

        Args:
            conversation: List of message dicts (``role``/``content`` pairs).
            tools: Optional tool/function schemas for tool-call rendering.
            tokenize: Whether to return token ids (True) or the rendered string.
            padding: Whether to pad the tokenized output.
            truncation: Whether to truncate to ``max_length``.
            max_length: Maximum length when truncating/padding.
            return_tensors: Framework tensor type to return (e.g. ``"pt"``).
            return_dict: Whether to return a dict of model inputs.
            add_generation_prompt: Whether to append the assistant prompt stub.
            **kwargs: Passed through to the underlying tokenizer.

        Returns:
            Whatever the underlying tokenizer's ``apply_chat_template`` returns.
        """
        return self.tokenizer.apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_dict=return_dict,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue