mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
remove process defaults, respect config init
This commit is contained in:
parent
0f8b60c119
commit
b5e7746c99
1 changed files with 135 additions and 80 deletions
|
|
@ -29,7 +29,6 @@ from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
|||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.cli import (
|
||||
adjust_model_defaults,
|
||||
extract_namespace,
|
||||
get_double_dash_flags,
|
||||
get_prefixed_pydantic_model,
|
||||
|
|
@ -1193,57 +1192,57 @@ class BaseEnv(ABC):
|
|||
type: The CliProcessConfig class for processing commands.
|
||||
"""
|
||||
|
||||
# Define specific default configurations for the 'process' mode
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
|
||||
group_size=8,
|
||||
total_steps=2,
|
||||
ensure_scores_are_not_same=False,
|
||||
include_messages=True,
|
||||
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
|
||||
use_wandb=True,
|
||||
# Get the default configurations from the specific environment class via config_init
|
||||
default_env_config_from_init, default_server_configs_from_init = (
|
||||
cls.config_init()
|
||||
)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
)
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG = ServerManagerConfig(
|
||||
slurm=False,
|
||||
testing=False,
|
||||
)
|
||||
|
||||
# Get the base default configurations from the specific environment class
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Create Pydantic model classes with the 'process' mode defaults applied.
|
||||
# These adjusted classes will be used for final instantiation.
|
||||
env_config_cls_new_defaults = adjust_model_defaults(
|
||||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
)
|
||||
openai_config_cls_new_defaults = adjust_model_defaults(
|
||||
APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
|
||||
)
|
||||
server_manager_config_cls_new_defaults = adjust_model_defaults(
|
||||
ServerManagerConfig,
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG,
|
||||
)
|
||||
# Create Pydantic model classes based on the types from config_init.
|
||||
# The defaults from config_init will be the primary source of defaults.
|
||||
env_config_cls_from_init = type(default_env_config_from_init)
|
||||
|
||||
# Handle server_configs_from_init appropriately for creating a default CLI model
|
||||
# If it's a list (multiple servers), we'll take the first one as a template for CLI args,
|
||||
# or use APIServerConfig if the list is empty or contains ServerBaseline.
|
||||
# If it's a single APIServerConfig, we use its type.
|
||||
# If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
|
||||
if isinstance(default_server_configs_from_init, list):
|
||||
if default_server_configs_from_init and isinstance(
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init[0])
|
||||
# Use the actual instance for default values later if it's a single config
|
||||
default_openai_config_instance_for_cli = (
|
||||
default_server_configs_from_init[0]
|
||||
if len(default_server_configs_from_init) == 1
|
||||
else openai_config_cls_for_cli()
|
||||
)
|
||||
else:
|
||||
openai_config_cls_for_cli = (
|
||||
APIServerConfig # Default to APIServerConfig for CLI definition
|
||||
)
|
||||
default_openai_config_instance_for_cli = APIServerConfig()
|
||||
elif isinstance(default_server_configs_from_init, APIServerConfig):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init)
|
||||
default_openai_config_instance_for_cli = default_server_configs_from_init
|
||||
else: # ServerBaseline or other
|
||||
openai_config_cls_for_cli = APIServerConfig
|
||||
default_openai_config_instance_for_cli = APIServerConfig()
|
||||
|
||||
class CliProcessConfig(
|
||||
get_prefixed_pydantic_model(env_config_cls_new_defaults, env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
openai_config_cls_new_defaults, openai_full_prefix
|
||||
),
|
||||
server_manager_config_cls_new_defaults,
|
||||
get_prefixed_pydantic_model(env_config_cls_from_init, env_full_prefix),
|
||||
get_prefixed_pydantic_model(openai_config_cls_for_cli, openai_full_prefix),
|
||||
ServerManagerConfig, # ServerManagerConfig defaults are fine as is.
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the process command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults.
|
||||
Order of precedence: CLI > YAML > `config_init` defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
|
|
@ -1273,12 +1272,22 @@ class BaseEnv(ABC):
|
|||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults
|
||||
# Priority: CLI > YAML > `config_init` defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
# Start with defaults from config_init
|
||||
env_config_dict_base = default_env_config_from_init.model_dump()
|
||||
# Apply specific overrides for process mode that are generally useful
|
||||
env_config_dict_base["ensure_scores_are_not_same"] = False
|
||||
env_config_dict_base["include_messages"] = True
|
||||
if env_config_dict_base.get("data_path_to_save_groups") is None:
|
||||
env_config_dict_base["data_path_to_save_groups"] = (
|
||||
f"data/{cls.name or 'groups'}.jsonl"
|
||||
)
|
||||
env_config_dict_base["use_wandb"] = True
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(), # Class Defaults
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
env_config_dict_base, # `config_init` defaults with process adjustments
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
|
@ -1288,37 +1297,37 @@ class BaseEnv(ABC):
|
|||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
if isinstance(default_server_configs, ServerBaseline) and (
|
||||
|
||||
# Determine the base OpenAI config from config_init for merging
|
||||
# This uses the instance we determined earlier for CLI definition defaults
|
||||
openai_config_dict_base = (
|
||||
default_openai_config_instance_for_cli.model_dump()
|
||||
)
|
||||
|
||||
if isinstance(default_server_configs_from_init, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
# If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
|
||||
# it implies an override intent for a single server.
|
||||
# We use the default_openai_config_instance_for_cli (which would be a default APIServerConfig)
|
||||
# as the base for merging, allowing it to be fully specified by YAML/CLI.
|
||||
pass # Base is already set correctly for this case
|
||||
|
||||
if (
|
||||
isinstance(default_server_configs, list)
|
||||
and len(default_server_configs) == 1
|
||||
):
|
||||
# can't use the same var name because it shadows the class variable and we get an error
|
||||
default_openai_config_ = default_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = default_server_configs
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
# If YAML specifies a single server config for OpenAI namespace
|
||||
yaml_oai_single_server_config = yaml_oai_config[0]
|
||||
elif isinstance(yaml_oai_config, dict):
|
||||
yaml_oai_single_server_config = yaml_oai_config
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
yaml_oai_single_server_config = {}
|
||||
|
||||
openai_config_dict = merge_dicts(
|
||||
openai_config_dict_base, # Default from config_init (or default APIServerConfig)
|
||||
yaml_oai_single_server_config, # YAML config for a single server
|
||||
oai_cli_passed_args, # CLI args
|
||||
)
|
||||
|
||||
# 3. Server Manager Configuration
|
||||
# Extract only relevant CLI flags
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
|
|
@ -1333,39 +1342,85 @@ class BaseEnv(ABC):
|
|||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
# Start with ServerManagerConfig defaults, then apply YAML, then CLI
|
||||
# For process mode, slurm and testing are typically False unless specified.
|
||||
server_manager_config_dict_base = ServerManagerConfig(
|
||||
slurm=False, testing=False
|
||||
).model_dump()
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
server_manager_config_dict_base,
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Use the classes with adjusted defaults for instantiation
|
||||
# Use the original class types from config_init (or APIServerConfig for OpenAI CLI)
|
||||
|
||||
env_config = env_config_cls_new_defaults(**env_config_dict)
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
env_config = env_config_cls_from_init(**env_config_dict)
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||
# Determine the final server_configs.
|
||||
# For 'process', we typically expect a single server configuration for the OAI part.
|
||||
# The resolve_openai_configs will handle complex cases, but for 'process',
|
||||
# the openai_config_dict we built should represent the single intended server.
|
||||
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
# If default_server_configs_from_init was ServerBaseline, and we have openai_config_dict,
|
||||
# it means we are overriding to use a specific APIServerConfig.
|
||||
# If default_server_configs_from_init was a list or single APIServerConfig,
|
||||
# resolve_openai_configs will merge appropriately.
|
||||
|
||||
final_openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs_from_init, # Pass the original structure
|
||||
openai_config_dict=openai_config_dict, # This is the merged single server config for CLI/YAML
|
||||
yaml_config=yaml_config, # Pass full YAML for resolve_openai_configs logic
|
||||
cli_passed_flags=cli_passed_flags, # Pass full CLI for resolve_openai_configs
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Add warning for localhost or 0.0.0.0
|
||||
if isinstance(final_openai_configs, list):
|
||||
for cfg in final_openai_configs:
|
||||
if (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
"localhost" in cfg.base_url
|
||||
or "0.0.0.0" in cfg.base_url
|
||||
or "127.0.0.1" in cfg.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
|
||||
"Ensure you have a server running at this address or results may not be generated.",
|
||||
UserWarning,
|
||||
)
|
||||
break # Warn once
|
||||
elif (
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
|
||||
"Ensure you have a server running at this address or results may not be generated.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
rprint(final_openai_configs)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=openai_configs,
|
||||
server_configs=final_openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue