remove process defaults, respect config init

This commit is contained in:
hjc-puro 2025-06-02 21:19:45 -04:00
parent 0f8b60c119
commit b5e7746c99

View file

@ -29,7 +29,6 @@ from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
from atroposlib.type_definitions import UUID
from atroposlib.utils.cli import (
adjust_model_defaults,
extract_namespace,
get_double_dash_flags,
get_prefixed_pydantic_model,
@ -1193,57 +1192,57 @@ class BaseEnv(ABC):
type: The CliProcessConfig class for processing commands.
"""
# Define specific default configurations for the 'process' mode
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
group_size=8,
total_steps=2,
ensure_scores_are_not_same=False,
include_messages=True,
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
use_wandb=True,
# Get the default configurations from the specific environment class via config_init
default_env_config_from_init, default_server_configs_from_init = (
cls.config_init()
)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=None,
)
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG = ServerManagerConfig(
slurm=False,
testing=False,
)
# Get the base default configurations from the specific environment class
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
# Create Pydantic model classes with the 'process' mode defaults applied.
# These adjusted classes will be used for final instantiation.
env_config_cls_new_defaults = adjust_model_defaults(
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
)
openai_config_cls_new_defaults = adjust_model_defaults(
APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
)
server_manager_config_cls_new_defaults = adjust_model_defaults(
ServerManagerConfig,
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG,
)
# Create Pydantic model classes based on the types from config_init.
# The defaults from config_init will be the primary source of defaults.
env_config_cls_from_init = type(default_env_config_from_init)
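# Pick the class (and the default instance) used to define the OpenAI CLI arguments,
# based on `default_server_configs_from_init`:
# - list whose first element is an APIServerConfig: use that element's type; its values
#   are used as defaults only when the list has exactly one entry, otherwise the class
#   is instantiated with its own defaults.
# - empty list, or first element is not an APIServerConfig: fall back to APIServerConfig.
# - single APIServerConfig: use its type and its values as defaults.
# - anything else (e.g. ServerBaseline): fall back to APIServerConfig so CLI overrides work.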
if isinstance(default_server_configs_from_init, list):
if default_server_configs_from_init and isinstance(
default_server_configs_from_init[0], APIServerConfig
):
openai_config_cls_for_cli = type(default_server_configs_from_init[0])
# Use the actual instance for default values later if it's a single config
default_openai_config_instance_for_cli = (
default_server_configs_from_init[0]
if len(default_server_configs_from_init) == 1
else openai_config_cls_for_cli()
)
else:
openai_config_cls_for_cli = (
APIServerConfig # Default to APIServerConfig for CLI definition
)
default_openai_config_instance_for_cli = APIServerConfig()
elif isinstance(default_server_configs_from_init, APIServerConfig):
openai_config_cls_for_cli = type(default_server_configs_from_init)
default_openai_config_instance_for_cli = default_server_configs_from_init
else: # ServerBaseline or other
openai_config_cls_for_cli = APIServerConfig
default_openai_config_instance_for_cli = APIServerConfig()
class CliProcessConfig(
get_prefixed_pydantic_model(env_config_cls_new_defaults, env_full_prefix),
get_prefixed_pydantic_model(
openai_config_cls_new_defaults, openai_full_prefix
),
server_manager_config_cls_new_defaults,
get_prefixed_pydantic_model(env_config_cls_from_init, env_full_prefix),
get_prefixed_pydantic_model(openai_config_cls_for_cli, openai_full_prefix),
ServerManagerConfig, # ServerManagerConfig defaults are fine as is.
Cmd,
):
"""
Configuration for the process command.
Supports overrides via YAML config file and CLI arguments.
Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults.
Order of precedence: CLI > YAML > `config_init` defaults.
"""
config: str | None = Field(
@ -1273,12 +1272,22 @@ class BaseEnv(ABC):
cli_passed_flags = get_double_dash_flags()
# --- Configuration Merging ---
# Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults
# Priority: CLI > YAML > `config_init` defaults
# 1. Environment Configuration
# Start with defaults from config_init
env_config_dict_base = default_env_config_from_init.model_dump()
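# Apply process-mode overrides on top of the config_init values.
# NOTE: ensure_scores_are_not_same, include_messages and use_wandb are set unconditionally
# (they replace whatever config_init returned); data_path_to_save_groups is only filled in
# when config_init left it as None. YAML and CLI values can still override all of these.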
env_config_dict_base["ensure_scores_are_not_same"] = False
env_config_dict_base["include_messages"] = True
if env_config_dict_base.get("data_path_to_save_groups") is None:
env_config_dict_base["data_path_to_save_groups"] = (
f"data/{cls.name or 'groups'}.jsonl"
)
env_config_dict_base["use_wandb"] = True
env_config_dict = merge_dicts(
default_env_config.model_dump(), # Class Defaults
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
env_config_dict_base, # `config_init` defaults with process adjustments
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
)
@ -1288,37 +1297,37 @@ class BaseEnv(ABC):
cli_passed_flags, openai_full_prefix
) # CLI args
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
if isinstance(default_server_configs, ServerBaseline) and (
# Determine the base OpenAI config from config_init for merging
# This uses the instance we determined earlier for CLI definition defaults
openai_config_dict_base = (
default_openai_config_instance_for_cli.model_dump()
)
if isinstance(default_server_configs_from_init, ServerBaseline) and (
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
)
# If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
# it implies an override intent for a single server.
# We use the default_openai_config_instance_for_cli (which would be a default APIServerConfig)
# as the base for merging, allowing it to be fully specified by YAML/CLI.
pass # Base is already set correctly for this case
if (
isinstance(default_server_configs, list)
and len(default_server_configs) == 1
):
# can't use the same var name because it shadows the class variable and we get an error
default_openai_config_ = default_server_configs[0]
else:
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_oai_config,
oai_cli_passed_args,
)
# If YAML specifies a single server config for OpenAI namespace
yaml_oai_single_server_config = yaml_oai_config[0]
elif isinstance(yaml_oai_config, dict):
yaml_oai_single_server_config = yaml_oai_config
else:
openai_config_dict = {}
yaml_oai_single_server_config = {}
openai_config_dict = merge_dicts(
openai_config_dict_base, # Default from config_init (or default APIServerConfig)
yaml_oai_single_server_config, # YAML config for a single server
oai_cli_passed_args, # CLI args
)
# 3. Server Manager Configuration
# Extract only relevant CLI flags
server_manager_cli_passed_flags = {}
if "slurm" in cli_passed_flags:
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
@ -1333,39 +1342,85 @@ class BaseEnv(ABC):
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
# Start with ServerManagerConfig defaults, then apply YAML, then CLI
# For process mode, slurm and testing are typically False unless specified.
server_manager_config_dict_base = ServerManagerConfig(
slurm=False, testing=False
).model_dump()
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(), # Base defaults
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
server_manager_config_dict_base,
server_manager_yaml_dict,
server_manager_cli_passed_flags, # CLI args
server_manager_cli_passed_flags,
)
# --- Instantiate Final Config Objects ---
# Use the classes with adjusted defaults for instantiation
# Use the original class types from config_init (or APIServerConfig for OpenAI CLI)
env_config = env_config_cls_new_defaults(**env_config_dict)
server_manager_config = server_manager_config_cls_new_defaults(
env_config = env_config_cls_from_init(**env_config_dict)
server_manager_config = ServerManagerConfig(
**server_manager_config_dict
)
# Determine the final server_configs, handling single, multiple servers, and overrides.
# Determine the final server_configs.
# For 'process', we typically expect a single server configuration for the OAI part.
# The resolve_openai_configs will handle complex cases, but for 'process',
# the openai_config_dict we built should represent the single intended server.
openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
# If default_server_configs_from_init was ServerBaseline, and we have openai_config_dict,
# it means we are overriding to use a specific APIServerConfig.
# If default_server_configs_from_init was a list or single APIServerConfig,
# resolve_openai_configs will merge appropriately.
final_openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs_from_init, # Pass the original structure
openai_config_dict=openai_config_dict, # This is the merged single server config for CLI/YAML
yaml_config=yaml_config, # Pass full YAML for resolve_openai_configs logic
cli_passed_flags=cli_passed_flags, # Pass full CLI for resolve_openai_configs
logger=logger,
)
# Warn (once) when any server's base URL points at a local address
# (localhost, 0.0.0.0, or 127.0.0.1)
if isinstance(final_openai_configs, list):
for cfg in final_openai_configs:
if (
isinstance(cfg, APIServerConfig)
and cfg.base_url
and (
"localhost" in cfg.base_url
or "0.0.0.0" in cfg.base_url
or "127.0.0.1" in cfg.base_url
)
):
warnings.warn(
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
"Ensure you have a server running at this address or results may not be generated.",
UserWarning,
)
break # Warn once
elif (
isinstance(final_openai_configs, APIServerConfig)
and final_openai_configs.base_url
and (
"localhost" in final_openai_configs.base_url
or "0.0.0.0" in final_openai_configs.base_url
or "127.0.0.1" in final_openai_configs.base_url
)
):
warnings.warn(
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
"Ensure you have a server running at this address or results may not be generated.",
UserWarning,
)
rprint(env_config)
rprint(openai_configs)
rprint(final_openai_configs)
# --- Create and Run Environment ---
# Create the environment instance
env = cls(
config=env_config,
server_configs=openai_configs,
server_configs=final_openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)