mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
propagate cli stuff to serve command
This commit is contained in:
parent
a8b59ccc9b
commit
af26b2e68a
1 changed files with 204 additions and 52 deletions
|
|
@ -939,30 +939,138 @@ class BaseEnv(ABC):
|
|||
type: The CliServeConfig class for serving commands.
|
||||
"""
|
||||
|
||||
env_config, server_configs = cls.config_init()
|
||||
# Get the default configurations defined by the specific environment class
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
# Determine a default OpenaiConfig instance for merging purposes,
|
||||
# even if the actual default is ServerBaseline or a list.
|
||||
# This allows overriding with OpenAI settings via CLI/YAML consistently.
|
||||
default_openai_config = OpenaiConfig() # Base default
|
||||
if isinstance(default_server_configs, OpenaiConfig):
|
||||
default_openai_config = default_server_configs
|
||||
elif isinstance(default_server_configs, list) and default_server_configs:
|
||||
# If the default is a list, use the first item if it's OpenaiConfig
|
||||
if isinstance(default_server_configs[0], OpenaiConfig):
|
||||
default_openai_config = default_server_configs[0]
|
||||
|
||||
# Define namespace prefixes for CLI arguments and YAML keys
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Define the CLI configuration class dynamically
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
|
||||
ServerManagerConfig,
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
OpenaiConfig, openai_full_prefix
|
||||
), # Use OpenaiConfig for CLI args
|
||||
ServerManagerConfig, # ServerManager args are not namespaced by default
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the serve command.
|
||||
This combines BaseEnvConfig and OpenaiConfig into a single command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > Class Defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'serve' command."""
|
||||
# Convert this config into the formats needed by BaseEnv
|
||||
# Set default wandb name if not provided and class has a name
|
||||
# Note: This modifies the 'self' instance based on CLI args before full parsing.
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
model_dumped = self.model_dump(exclude_unset=True)
|
||||
server_manager_config = ServerManagerConfig(**model_dumped)
|
||||
# Create the environment instance
|
||||
|
||||
# Load configuration from YAML file if specified
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
# Get CLI flags passed with double dashes (e.g., --env--foo bar)
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > Class Defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(), # Class Defaults
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 2. OpenAI Configuration (used for potential overrides)
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config.model_dump(), # Default OpenaiConfig (or from class init)
|
||||
yaml_config.get(OPENAI_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, openai_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 3. Server Manager Configuration (slurm, testing - not namespaced)
|
||||
# Extract only relevant CLI flags for ServerManager
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults for ServerManager
|
||||
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Create instances from the merged dictionaries using the original default types where appropriate
|
||||
|
||||
# Instantiate the final environment config using its original type
|
||||
env_config = type(default_env_config)(**env_config_dict)
|
||||
|
||||
# Instantiate the final server manager config
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs based on the original default type from cls.config_init()
|
||||
# This allows handling ServerBaseline, single OpenaiConfig, or list[OpenaiConfig]
|
||||
server_configs = (
|
||||
default_server_configs # Start with the original default
|
||||
)
|
||||
if isinstance(default_server_configs, OpenaiConfig):
|
||||
# If default was single OpenaiConfig, update it with merged values
|
||||
server_configs = OpenaiConfig(**openai_config_dict)
|
||||
elif isinstance(default_server_configs, list):
|
||||
# If default was list (presumably of OpenaiConfig), update the first or create one
|
||||
# This assumes the primary server config is the one overridden via CLI/YAML
|
||||
if default_server_configs and isinstance(
|
||||
default_server_configs[0], OpenaiConfig
|
||||
):
|
||||
# Update the first element, keep others as they were
|
||||
server_configs = [
|
||||
OpenaiConfig(**openai_config_dict)
|
||||
] + default_server_configs[1:]
|
||||
else:
|
||||
# If list was empty or didn't contain OpenaiConfig, create a new list with the merged config
|
||||
server_configs = [OpenaiConfig(**openai_config_dict)]
|
||||
# If default_server_configs was ServerBaseline, server_configs remains ServerBaseline,
|
||||
# effectively ignoring the openai_config_dict unless the user explicitly provides
|
||||
# OpenaiConfig settings via CLI/YAML, which would be captured but not used here unless
|
||||
# the environment class's config_init returned an OpenaiConfig or list.
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance using the final, instantiated config objects
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=server_configs,
|
||||
|
|
@ -970,7 +1078,7 @@ class BaseEnv(ABC):
|
|||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
||||
# Run the environment
|
||||
# Run the environment's main asynchronous manager function
|
||||
asyncio.run(env.env_manager())
|
||||
|
||||
return CliServeConfig
|
||||
|
|
@ -984,51 +1092,76 @@ class BaseEnv(ABC):
|
|||
type: The CliProcessConfig class for processing commands.
|
||||
"""
|
||||
|
||||
# Define specific default configurations for the 'process' mode
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
|
||||
group_size=8,
|
||||
total_steps=2,
|
||||
ensure_scores_are_not_same=False,
|
||||
include_messages=True,
|
||||
# Ensure a default path for process mode if not set by class/cli/yaml
|
||||
data_path_to_save_groups="output_groups.jsonl",
|
||||
use_wandb=False, # Typically disable wandb for simple processing
|
||||
)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
model_name="gpt-4.1-nano", # A reasonable default for processing
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
)
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG = ServerManagerConfig(
|
||||
slurm=False,
|
||||
slurm=False, # Usually run locally
|
||||
testing=False,
|
||||
)
|
||||
|
||||
default_env_config, default_openai_config = cls.config_init()
|
||||
# Get the base default configurations from the specific environment class
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
if isinstance(default_openai_config, list):
|
||||
default_openai_config = default_openai_config[0]
|
||||
# Ensure default_openai_config is a single instance for default merging logic.
|
||||
# Process mode specifically uses OpenaiConfig, so we establish a base default.
|
||||
if isinstance(default_server_configs, list):
|
||||
# Use the first if available and is OpenaiConfig, otherwise use a base OpenaiConfig
|
||||
default_openai_config = (
|
||||
default_server_configs[0]
|
||||
if default_server_configs
|
||||
and isinstance(default_server_configs[0], OpenaiConfig)
|
||||
else OpenaiConfig()
|
||||
)
|
||||
elif isinstance(default_server_configs, OpenaiConfig):
|
||||
default_openai_config = default_server_configs
|
||||
else:
|
||||
# If config_init returned ServerBaseline or something else, use a base OpenaiConfig for defaults
|
||||
default_openai_config = OpenaiConfig()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Create Pydantic model classes with the 'process' mode defaults applied.
|
||||
# These adjusted classes will be used for final instantiation.
|
||||
env_config_cls_new_defaults = adjust_model_defaults(
|
||||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
)
|
||||
openai_config_cls_new_defaults = adjust_model_defaults(
|
||||
OpenaiConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
|
||||
OpenaiConfig,
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG, # Process always uses OpenaiConfig type
|
||||
)
|
||||
server_manager_config_cls_new_defaults = adjust_model_defaults(
|
||||
ServerManagerConfig,
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
# Define the CLI configuration class dynamically
|
||||
class CliProcessConfig(
|
||||
get_prefixed_pydantic_model(env_config_cls_new_defaults, env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
openai_config_cls_new_defaults, openai_full_prefix
|
||||
),
|
||||
server_manager_config_cls_new_defaults,
|
||||
server_manager_config_cls_new_defaults, # Uses adjusted defaults
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the process command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > Class Defaults > Process Mode Defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
|
|
@ -1038,42 +1171,46 @@ class BaseEnv(ABC):
|
|||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'process' command."""
|
||||
# Setup environment configuration
|
||||
# Set default wandb name if not provided and class has a name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
# Load configuration from YAML file if specified
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
config = {}
|
||||
yaml_config = {}
|
||||
|
||||
# Get CLI flags passed with double dashes
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# cli args overrides config file which overrides class defaults which overrides process mode defaults
|
||||
env_config = env_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, env_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
openai_config = openai_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_openai_config.model_dump(),
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(OPENAI_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > Class Defaults > Process Mode Defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(), # Class Defaults
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 2. OpenAI Configuration
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config.model_dump(), # Class Defaults (adjusted to be OpenaiConfig)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_config.get(OPENAI_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, openai_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 3. Server Manager Configuration
|
||||
# Extract only relevant CLI flags
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
|
|
@ -1082,36 +1219,51 @@ class BaseEnv(ABC):
|
|||
"testing"
|
||||
]
|
||||
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(SERVER_MANAGER_NAMESPACE, {}),
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Use the classes with adjusted defaults for instantiation
|
||||
|
||||
env_config = env_config_cls_new_defaults(**env_config_dict)
|
||||
openai_config = openai_config_cls_new_defaults(**openai_config_dict)
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
# Process mode always uses a single OpenAI config, passed as a list
|
||||
server_configs=[openai_config],
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
||||
# Set the process mode parameters
|
||||
# Set specific parameters for process mode on the environment instance
|
||||
env.process_mode = True
|
||||
env.n_groups_to_process = env_config.total_steps
|
||||
env.group_size_to_process = env_config.group_size
|
||||
|
||||
# Validate that an output path is set (should have a default from PROCESS_MODE_ENV_DEFAULT_CONFIG)
|
||||
if env_config.data_path_to_save_groups is None:
|
||||
# This check might be redundant if the default is always set, but good practice.
|
||||
raise ValueError(
|
||||
"data_path_to_save_groups must be set for process mode"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Processing {env_config.total_steps} groups of "
|
||||
f"{env_config.group_size} responses and "
|
||||
f"writing to {env_config.data_path_to_save_groups}"
|
||||
)
|
||||
|
||||
# Run the environment's asynchronous process manager function
|
||||
asyncio.run(env.process_manager())
|
||||
|
||||
# Actual implementation would go here
|
||||
|
||||
return CliProcessConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue