diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index 00d2bf5d..ac6a2058 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -1,4 +1,5 @@ import asyncio +import inspect import os from contextlib import asynccontextmanager from typing import AsyncGenerator, List, Union @@ -34,8 +35,9 @@ class ServerManager: slurm=False, testing=False, ): - if type(server_class) is APIServer: - if isinstance(configs, ServerBaseline): + # First we check to see if it's the base server class, and if so, we need to select the appropriate server class + if inspect.isabstract(server_class): + if not isinstance(configs, list): if configs.server_type == "openai": server_class = OpenAIServer elif configs.server_type == "trl": @@ -53,7 +55,7 @@ class ServerManager: # testing :) self.servers = [ServerHarness()] return - if isinstance(configs, ServerBaseline): + if not isinstance(configs, list): urls = [] if os.environ.get("SLURM_JOB_NODELIST", None) is not None: nodelist = (