diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py
index 5b88f388..3040c9ca 100644
--- a/atroposlib/envs/server_handling/server_baseline.py
+++ b/atroposlib/envs/server_handling/server_baseline.py
@@ -250,6 +250,10 @@ class ServerBaseline(BaseModel):
     server_type: Literal["openai", "trl", "sglang", "vllm"] = Field(
         default="openai", description="Type of server to use"
     )
+    tokenizer_name: str = Field(
+        default="none",
+        description="The tokenizer name to use. If none, will use the model_name as the tokenizer.",
+    )


 class APIServerConfig(ServerBaseline):
diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index 9e760564..e76dea32 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -121,6 +121,7 @@ class ServerManager:
                     model_name=configs.model_name,
                     rolling_buffer_length=configs.rolling_buffer_length,
                     api_key="x",
+                    tokenizer_name=configs.tokenizer_name,
                 )
             )
         self.servers = [
diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py
index 19ac3d6e..63201b3e 100644
--- a/atroposlib/envs/server_handling/sglang_server.py
+++ b/atroposlib/envs/server_handling/sglang_server.py
@@ -31,7 +31,12 @@ class SGLangServer(APIServer):
             base_url=config.base_url,
             timeout=config.timeout,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+        tokenizer_name = (
+            config.model_name
+            if config.tokenizer_name == "none"
+            else config.tokenizer_name
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         super().__init__(config, reasoning_config=reasoning_config)

     async def check_server_status_task(self, chat_completion: bool = True):
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 41fec651..96242754 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -34,7 +34,12 @@ class VLLMServer(APIServer):
             base_url=config.base_url,
             timeout=config.timeout,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+        tokenizer_name = (
+            config.model_name
+            if config.tokenizer_name == "none"
+            else config.tokenizer_name
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         super().__init__(config, reasoning_config=reasoning_config)

     async def check_server_status_task(self, chat_completion: bool = True):
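
For context, a minimal usage sketch of the new field is below. It assumes `APIServerConfig` accepts the fields visible in this diff (`model_name`, `base_url`, `api_key`, `server_type`, `tokenizer_name`); the specific model and endpoint values are placeholders, not part of the change.

```python
from atroposlib.envs.server_handling.server_baseline import APIServerConfig

# Hypothetical values for illustration only.
config = APIServerConfig(
    model_name="my-org/my-finetuned-checkpoint",  # checkpoint served by the backend
    base_url="http://localhost:9000/v1",          # example endpoint
    api_key="x",
    server_type="vllm",
    # Point the tokenizer at a different repo when the served checkpoint does not
    # ship its own tokenizer files. The default "none" falls back to model_name,
    # preserving the previous behaviour.
    tokenizer_name="meta-llama/Llama-3.1-8B-Instruct",
)
```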