mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
refactor to not mess up process...
This commit is contained in:
parent
6e9405ba95
commit
df62979b90
5 changed files with 346 additions and 424 deletions
|
|
@ -40,6 +40,7 @@ from atroposlib.utils.metrics import get_std_min_max_avg
|
|||
|
||||
from ..type_definitions import Item, Message
|
||||
from .server_handling.server_manager import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ServerBaseline,
|
||||
ServerManager,
|
||||
|
|
@ -163,8 +164,9 @@ class BaseEnvConfig(BaseModel):
|
|||
|
||||
class BaseEnv(ABC):
|
||||
|
||||
name = None
|
||||
env_config_cls = BaseEnvConfig
|
||||
name: Optional[str] = None
|
||||
env_config_cls: BaseEnvConfig = BaseEnvConfig
|
||||
server_cls: APIServer = APIServer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -172,7 +174,6 @@ class BaseEnv(ABC):
|
|||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
server_class=None,
|
||||
):
|
||||
self.items_sent_this_step = 0
|
||||
self.eval_runner = None # type: Optional[asyncio.Task]
|
||||
|
|
@ -186,7 +187,7 @@ class BaseEnv(ABC):
|
|||
self.last_completed_item = None
|
||||
self.config = config
|
||||
self.server = ServerManager(
|
||||
server_configs, slurm=slurm, testing=testing, server_class=server_class
|
||||
server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
|
||||
)
|
||||
self.workers = set()
|
||||
self.eval_workers = set()
|
||||
|
|
@ -1001,14 +1002,7 @@ class BaseEnv(ABC):
|
|||
type: The CliServeConfig class for serving commands.
|
||||
"""
|
||||
# Get the default configurations defined by the specific environment class
|
||||
configs_and_maybe_server_class = cls.config_init()
|
||||
if len(configs_and_maybe_server_class) == 2:
|
||||
default_env_config, default_server_configs = configs_and_maybe_server_class
|
||||
server_class = None
|
||||
else:
|
||||
default_env_config, default_server_configs, server_class = (
|
||||
configs_and_maybe_server_class
|
||||
)
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
# Define namespace prefixes for CLI arguments and YAML keys
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -1186,14 +1180,7 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
# Get the base default configurations from the specific environment class
|
||||
configs_and_maybe_server_class = cls.config_init()
|
||||
if len(configs_and_maybe_server_class) == 2:
|
||||
default_env_config, default_server_configs = configs_and_maybe_server_class
|
||||
server_class = None
|
||||
else:
|
||||
default_env_config, default_server_configs, server_class = (
|
||||
configs_and_maybe_server_class
|
||||
)
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -1348,7 +1335,6 @@ class BaseEnv(ABC):
|
|||
server_configs=openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
server_class=server_class,
|
||||
)
|
||||
|
||||
# Set specific parameters for process mode on the environment instance
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue