mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
--slurm and --testing in outer namespace
This commit is contained in:
parent
9a8ae1630b
commit
60d67d91e7
2 changed files with 20 additions and 19 deletions
|
|
@ -21,6 +21,8 @@ import yaml
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import (
|
||||
ENV_NAMESPACE,
|
||||
NAMESPACE_SEP,
|
||||
|
|
@ -28,6 +30,7 @@ from atroposlib.envs.constants import (
|
|||
SERVER_MANAGER_NAMESPACE,
|
||||
)
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.cli import (
|
||||
adjust_model_defaults,
|
||||
extract_namespace,
|
||||
|
|
@ -35,9 +38,6 @@ from atroposlib.utils.cli import (
|
|||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.metrics import get_std_min_max_avg
|
||||
|
||||
from ..type_definitions import Item, Message
|
||||
|
|
@ -160,7 +160,7 @@ class BaseEnv(ABC):
|
|||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
|
||||
slurm=True,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
self.items_sent_this_step = 0
|
||||
|
|
@ -195,6 +195,10 @@ class BaseEnv(ABC):
|
|||
self.checkpoint_dir = ""
|
||||
self.checkpoint_interval = -1
|
||||
if self.config.data_path_to_save_groups is not None:
|
||||
|
||||
Path(self.config.data_path_to_save_groups).parent.mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
# Find a suitable filename by appending _1, _2, etc. if the file already exists
|
||||
original_path = self.config.data_path_to_save_groups
|
||||
counter = 1
|
||||
|
|
@ -937,15 +941,11 @@ class BaseEnv(ABC):
|
|||
env_config, server_configs = cls.config_init()
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
ServerManagerConfig,
|
||||
server_full_prefix,
|
||||
),
|
||||
ServerManagerConfig,
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
|
|
@ -1006,7 +1006,6 @@ class BaseEnv(ABC):
|
|||
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
env_config_cls_new_defaults = adjust_model_defaults(
|
||||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
|
|
@ -1024,9 +1023,7 @@ class BaseEnv(ABC):
|
|||
get_prefixed_pydantic_model(
|
||||
openai_config_cls_new_defaults, openai_full_prefix
|
||||
),
|
||||
get_prefixed_pydantic_model(
|
||||
server_manager_config_cls_new_defaults, server_full_prefix
|
||||
),
|
||||
server_manager_config_cls_new_defaults,
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
|
|
@ -1075,14 +1072,21 @@ class BaseEnv(ABC):
|
|||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(SERVER_MANAGER_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, server_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1100,9 +1104,6 @@ class BaseEnv(ABC):
|
|||
env.group_size_to_process = env_config.group_size
|
||||
|
||||
print(
|
||||
f"Processing {self.n_groups} groups of "
|
||||
f"{self.group_size} responses and "
|
||||
f"writing to {self.output_file}"
|
||||
f"Processing {env_config.total_steps} groups of "
|
||||
f"{env_config.group_size} responses and "
|
||||
f"writing to {env_config.data_path_to_save_groups}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue