mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
hide complicated openai config override behavior somewhere else
This commit is contained in:
parent
fe616ec7fa
commit
4348dd2ec1
3 changed files with 166 additions and 41 deletions
|
|
@ -20,15 +20,12 @@ import wandb
|
|||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
|
||||
from rich import print as rprint
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import (
|
||||
ENV_NAMESPACE,
|
||||
NAMESPACE_SEP,
|
||||
OPENAI_NAMESPACE,
|
||||
SERVER_MANAGER_NAMESPACE,
|
||||
)
|
||||
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.cli import (
|
||||
|
|
@ -1022,11 +1019,29 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
# 2. OpenAI Configuration (used for potential overrides)
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config.model_dump(), # Default OpenaiConfig (or from class init)
|
||||
yaml_config.get(OPENAI_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, openai_full_prefix), # CLI args
|
||||
)
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
if (
|
||||
isinstance(default_openai_config, list)
|
||||
and len(default_openai_config) == 1
|
||||
):
|
||||
# can't use the same var name because it shadows the class variable and we get an error
|
||||
default_openai_config_ = default_openai_config[0]
|
||||
else:
|
||||
default_openai_config_ = default_openai_config
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config, OpenaiConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
), # CLI args
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
# 3. Server Manager Configuration (slurm, testing - not namespaced)
|
||||
# Extract only relevant CLI flags for ServerManager
|
||||
|
|
@ -1038,9 +1053,15 @@ class BaseEnv(ABC):
|
|||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults for ServerManager
|
||||
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
|
||||
server_manager_yaml_dict, # YAML config
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
|
|
@ -1055,40 +1076,26 @@ class BaseEnv(ABC):
|
|||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs based on the original default type from cls.config_init()
|
||||
# This allows handling ServerBaseline, single OpenaiConfig, or list[OpenaiConfig]
|
||||
server_configs = (
|
||||
default_server_configs # Start with the original default
|
||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
if isinstance(default_server_configs, OpenaiConfig):
|
||||
# If default was single OpenaiConfig, update it with merged values
|
||||
server_configs = OpenaiConfig(**openai_config_dict)
|
||||
elif isinstance(default_server_configs, list):
|
||||
# If default was list (presumably of OpenaiConfig), update the first or create one
|
||||
# This assumes the primary server config is the one overridden via CLI/YAML
|
||||
if default_server_configs and isinstance(
|
||||
default_server_configs[0], OpenaiConfig
|
||||
):
|
||||
# Update the first element, keep others as they were
|
||||
server_configs = [
|
||||
OpenaiConfig(**openai_config_dict)
|
||||
] + default_server_configs[1:]
|
||||
else:
|
||||
# If list was empty or didn't contain OpenaiConfig, create a new list with the merged config
|
||||
server_configs = [OpenaiConfig(**openai_config_dict)]
|
||||
# If default_server_configs was ServerBaseline, server_configs remains ServerBaseline,
|
||||
# effectively ignoring the openai_config_dict unless the user explicitly provides
|
||||
# OpenaiConfig settings via CLI/YAML, which would be captured but not used here unless
|
||||
# the environment class's config_init returned an OpenaiConfig or list.
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance using the final, instantiated config objects
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=server_configs,
|
||||
server_configs=openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
|
||||
# Run the environment's main asynchronous manager function
|
||||
asyncio.run(env.env_manager())
|
||||
|
|
@ -1231,10 +1238,16 @@ class BaseEnv(ABC):
|
|||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
|
|
@ -1242,17 +1255,29 @@ class BaseEnv(ABC):
|
|||
# Use the classes with adjusted defaults for instantiation
|
||||
|
||||
env_config = env_config_cls_new_defaults(**env_config_dict)
|
||||
openai_config = openai_config_cls_new_defaults(**openai_config_dict)
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
# Process mode always uses a single OpenAI config, passed as a list
|
||||
server_configs=[openai_config],
|
||||
# Use the resolved configs (single or list)
|
||||
server_configs=openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue