--slurm and --testing in outer namespace

This commit is contained in:
hjc-puro 2025-05-02 03:46:34 -07:00
parent 9a8ae1630b
commit 60d67d91e7
2 changed files with 20 additions and 19 deletions

View file

@ -21,6 +21,8 @@ import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from atroposlib.envs.constants import (
ENV_NAMESPACE,
NAMESPACE_SEP,
@ -28,6 +30,7 @@ from atroposlib.envs.constants import (
SERVER_MANAGER_NAMESPACE,
)
from atroposlib.frontend.jsonl2html import generate_html
from atroposlib.type_definitions import UUID
from atroposlib.utils.cli import (
adjust_model_defaults,
extract_namespace,
@ -35,9 +38,6 @@ from atroposlib.utils.cli import (
get_prefixed_pydantic_model,
merge_dicts,
)
from transformers import AutoTokenizer
from atroposlib.type_definitions import UUID
from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
@ -160,7 +160,7 @@ class BaseEnv(ABC):
self,
config: BaseEnvConfig,
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
slurm=True,
slurm=False,
testing=False,
):
self.items_sent_this_step = 0
@ -195,6 +195,10 @@ class BaseEnv(ABC):
self.checkpoint_dir = ""
self.checkpoint_interval = -1
if self.config.data_path_to_save_groups is not None:
Path(self.config.data_path_to_save_groups).parent.mkdir(
parents=True, exist_ok=True
)
# Find a suitable filename by appending _1, _2, etc. if the file already exists
original_path = self.config.data_path_to_save_groups
counter = 1
@ -937,15 +941,11 @@ class BaseEnv(ABC):
env_config, server_configs = cls.config_init()
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
class CliServeConfig(
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
get_prefixed_pydantic_model(
ServerManagerConfig,
server_full_prefix,
),
ServerManagerConfig,
Cmd,
):
"""
@ -1006,7 +1006,6 @@ class BaseEnv(ABC):
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
env_config_cls_new_defaults = adjust_model_defaults(
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
@ -1024,9 +1023,7 @@ class BaseEnv(ABC):
get_prefixed_pydantic_model(
openai_config_cls_new_defaults, openai_full_prefix
),
get_prefixed_pydantic_model(
server_manager_config_cls_new_defaults, server_full_prefix
),
server_manager_config_cls_new_defaults,
Cmd,
):
"""
@ -1075,14 +1072,21 @@ class BaseEnv(ABC):
), # only extract namespace for cli-passed args
)
)
server_manager_cli_passed_flags = {}
if "slurm" in cli_passed_flags:
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
if "testing" in cli_passed_flags:
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
"testing"
]
server_manager_config = server_manager_config_cls_new_defaults(
**merge_dicts(
ServerManagerConfig().model_dump(),
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
config.get(SERVER_MANAGER_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, server_full_prefix
), # only extract namespace for cli-passed args
server_manager_cli_passed_flags,
)
)
@ -1100,9 +1104,6 @@ class BaseEnv(ABC):
env.group_size_to_process = env_config.group_size
print(
f"Processing {self.n_groups} groups of "
f"{self.group_size} responses and "
f"writing to {self.output_file}"
f"Processing {env_config.total_steps} groups of "
f"{env_config.group_size} responses and "
f"writing to {env_config.data_path_to_save_groups}"