mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add a `tokenizer_name` config option so the vLLM/SGLang servers can load a tokenizer other than the one implied by `model_name` (the default, "none", falls back to `model_name`)
This commit is contained in:
parent
13f282aabc
commit
7d6aeb9bbf
4 changed files with 17 additions and 2 deletions
|
|
@ -250,6 +250,10 @@ class ServerBaseline(BaseModel):
|
|||
server_type: Literal["openai", "trl", "sglang", "vllm"] = Field(
|
||||
default="openai", description="Type of server to use"
|
||||
)
|
||||
tokenizer_name: str = Field(
|
||||
default="none",
|
||||
description="The tokenizer name to use. If none, will use the model_name as the tokenizer.",
|
||||
)
|
||||
|
||||
|
||||
class APIServerConfig(ServerBaseline):
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ class ServerManager:
|
|||
model_name=configs.model_name,
|
||||
rolling_buffer_length=configs.rolling_buffer_length,
|
||||
api_key="x",
|
||||
tokenizer_name=configs.tokenizer_name,
|
||||
)
|
||||
)
|
||||
self.servers = [
|
||||
|
|
|
|||
|
|
@ -31,7 +31,12 @@ class SGLangServer(APIServer):
|
|||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer_name = (
|
||||
config.model_name
|
||||
if config.tokenizer_name == "none"
|
||||
else config.tokenizer_name
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ class VLLMServer(APIServer):
|
|||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer_name = (
|
||||
config.model_name
|
||||
if config.tokenizer_name == "none"
|
||||
else config.tokenizer_name
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue