refactor to not mess up process...

This commit is contained in:
dmahan93 2025-05-13 09:22:07 -05:00
parent 6e9405ba95
commit df62979b90
5 changed files with 346 additions and 424 deletions

View file

@ -40,6 +40,7 @@ from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
from .server_handling.server_manager import (
APIServer,
APIServerConfig,
ServerBaseline,
ServerManager,
@ -163,8 +164,9 @@ class BaseEnvConfig(BaseModel):
class BaseEnv(ABC):
name = None
env_config_cls = BaseEnvConfig
name: Optional[str] = None
env_config_cls: BaseEnvConfig = BaseEnvConfig
server_cls: APIServer = APIServer
def __init__(
self,
@ -172,7 +174,6 @@ class BaseEnv(ABC):
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm=False,
testing=False,
server_class=None,
):
self.items_sent_this_step = 0
self.eval_runner = None # type: Optional[asyncio.Task]
@ -186,7 +187,7 @@ class BaseEnv(ABC):
self.last_completed_item = None
self.config = config
self.server = ServerManager(
server_configs, slurm=slurm, testing=testing, server_class=server_class
server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
)
self.workers = set()
self.eval_workers = set()
@ -1001,14 +1002,7 @@ class BaseEnv(ABC):
type: The CliServeConfig class for serving commands.
"""
# Get the default configurations defined by the specific environment class
configs_and_maybe_server_class = cls.config_init()
if len(configs_and_maybe_server_class) == 2:
default_env_config, default_server_configs = configs_and_maybe_server_class
server_class = None
else:
default_env_config, default_server_configs, server_class = (
configs_and_maybe_server_class
)
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes for CLI arguments and YAML keys
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@ -1186,14 +1180,7 @@ class BaseEnv(ABC):
)
# Get the base default configurations from the specific environment class
configs_and_maybe_server_class = cls.config_init()
if len(configs_and_maybe_server_class) == 2:
default_env_config, default_server_configs = configs_and_maybe_server_class
server_class = None
else:
default_env_config, default_server_configs, server_class = (
configs_and_maybe_server_class
)
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@ -1348,7 +1335,6 @@ class BaseEnv(ABC):
server_configs=openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
server_class=server_class,
)
# Set specific parameters for process mode on the environment instance