hide complicated openai config override behavior somewhere else

This commit is contained in:
hjc-puro 2025-05-03 14:18:50 -07:00
parent fe616ec7fa
commit 4348dd2ec1
3 changed files with 166 additions and 41 deletions

View file

@ -20,15 +20,12 @@ import wandb
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
from rich import print as rprint
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from atroposlib.envs.constants import (
ENV_NAMESPACE,
NAMESPACE_SEP,
OPENAI_NAMESPACE,
SERVER_MANAGER_NAMESPACE,
)
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
from atroposlib.type_definitions import UUID
from atroposlib.utils.cli import (
@ -1022,11 +1019,29 @@ class BaseEnv(ABC):
)
# 2. OpenAI Configuration (used for potential overrides)
openai_config_dict = merge_dicts(
default_openai_config.model_dump(), # Default OpenaiConfig (or from class init)
yaml_config.get(OPENAI_NAMESPACE, {}), # YAML config
extract_namespace(cli_passed_flags, openai_full_prefix), # CLI args
)
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
if (
isinstance(default_openai_config, list)
and len(default_openai_config) == 1
):
# can't use the same var name because it shadows the class variable and we get an error
default_openai_config_ = default_openai_config[0]
else:
default_openai_config_ = default_openai_config
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config, OpenaiConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
yaml_oai_config,
extract_namespace(
cli_passed_flags, openai_full_prefix
), # CLI args
)
else:
openai_config_dict = {}
# 3. Server Manager Configuration (slurm, testing - not namespaced)
# Extract only relevant CLI flags for ServerManager
@ -1038,9 +1053,15 @@ class BaseEnv(ABC):
"testing"
]
server_manager_yaml_dict = {}
if "slurm" in yaml_config:
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(), # Base defaults for ServerManager
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
server_manager_yaml_dict, # YAML config
server_manager_cli_passed_flags, # CLI args
)
@ -1055,40 +1076,26 @@ class BaseEnv(ABC):
**server_manager_config_dict
)
# Determine the final server_configs based on the original default type from cls.config_init()
# This allows handling ServerBaseline, single OpenaiConfig, or list[OpenaiConfig]
server_configs = (
default_server_configs # Start with the original default
# Determine the final server_configs, handling single, multiple servers, and overrides.
openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
logger=logger,
)
if isinstance(default_server_configs, OpenaiConfig):
# If default was single OpenaiConfig, update it with merged values
server_configs = OpenaiConfig(**openai_config_dict)
elif isinstance(default_server_configs, list):
# If default was list (presumably of OpenaiConfig), update the first or create one
# This assumes the primary server config is the one overridden via CLI/YAML
if default_server_configs and isinstance(
default_server_configs[0], OpenaiConfig
):
# Update the first element, keep others as they were
server_configs = [
OpenaiConfig(**openai_config_dict)
] + default_server_configs[1:]
else:
# If list was empty or didn't contain OpenaiConfig, create a new list with the merged config
server_configs = [OpenaiConfig(**openai_config_dict)]
# If default_server_configs was ServerBaseline, server_configs remains ServerBaseline,
# effectively ignoring the openai_config_dict unless the user explicitly provides
# OpenaiConfig settings via CLI/YAML, which would be captured but not used here unless
# the environment class's config_init returned an OpenaiConfig or list.
# --- Create and Run Environment ---
# Create the environment instance using the final, instantiated config objects
env = cls(
config=env_config,
server_configs=server_configs,
server_configs=openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)
rprint(env_config)
rprint(openai_configs)
# Run the environment's main asynchronous manager function
asyncio.run(env.env_manager())
@ -1231,10 +1238,16 @@ class BaseEnv(ABC):
"testing"
]
server_manager_yaml_dict = {}
if "slurm" in yaml_config:
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(), # Base defaults
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
server_manager_yaml_dict,
server_manager_cli_passed_flags, # CLI args
)
@ -1242,17 +1255,29 @@ class BaseEnv(ABC):
# Use the classes with adjusted defaults for instantiation
env_config = env_config_cls_new_defaults(**env_config_dict)
openai_config = openai_config_cls_new_defaults(**openai_config_dict)
server_manager_config = server_manager_config_cls_new_defaults(
**server_manager_config_dict
)
# Determine the final server_configs, handling single, multiple servers, and overrides.
openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
logger=logger,
)
rprint(env_config)
rprint(openai_configs)
# --- Create and Run Environment ---
# Create the environment instance
env = cls(
config=env_config,
# Process mode always uses a single OpenAI config, passed as a list
server_configs=[openai_config],
# Use the resolved configs (single or list)
server_configs=openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)