diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index c8aaf53a..2c2c8612 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -20,15 +20,12 @@ import wandb import yaml from pydantic import BaseModel, Field from pydantic_cli import Cmd, FailedExecutionException, run_and_exit +from rich import print as rprint from tenacity import retry, stop_after_attempt, wait_random_exponential from transformers import AutoTokenizer -from atroposlib.envs.constants import ( - ENV_NAMESPACE, - NAMESPACE_SEP, - OPENAI_NAMESPACE, - SERVER_MANAGER_NAMESPACE, -) +from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from atroposlib.envs.server_handling.openai_server import resolve_openai_configs from atroposlib.frontend.jsonl2html import generate_html from atroposlib.type_definitions import UUID from atroposlib.utils.cli import ( @@ -1022,11 +1019,29 @@ class BaseEnv(ABC): ) # 2. OpenAI Configuration (used for potential overrides) - openai_config_dict = merge_dicts( - default_openai_config.model_dump(), # Default OpenaiConfig (or from class init) - yaml_config.get(OPENAI_NAMESPACE, {}), # YAML config - extract_namespace(cli_passed_flags, openai_full_prefix), # CLI args - ) + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + if ( + isinstance(default_openai_config, list) + and len(default_openai_config) == 1 + ): + # can't use the same var name because it shadows the class variable and we get an error + default_openai_config_ = default_openai_config[0] + else: + default_openai_config_ = default_openai_config + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + if isinstance(default_openai_config, OpenaiConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init) + yaml_oai_config, + extract_namespace( + cli_passed_flags, openai_full_prefix + ), # CLI args + ) + else: + openai_config_dict = {} # 3. Server Manager Configuration (slurm, testing - not namespaced) # Extract only relevant CLI flags for ServerManager @@ -1038,9 +1053,15 @@ class BaseEnv(ABC): "testing" ] + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + server_manager_config_dict = merge_dicts( ServerManagerConfig().model_dump(), # Base defaults for ServerManager - yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config + server_manager_yaml_dict, # YAML config server_manager_cli_passed_flags, # CLI args ) @@ -1055,40 +1076,26 @@ class BaseEnv(ABC): **server_manager_config_dict ) - # Determine the final server_configs based on the original default type from cls.config_init() - # This allows handling ServerBaseline, single OpenaiConfig, or list[OpenaiConfig] - server_configs = ( - default_server_configs # Start with the original default + # Determine the final server_configs, handling single, multiple servers, and overrides. + + openai_configs = resolve_openai_configs( + default_server_configs=default_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, ) - if isinstance(default_server_configs, OpenaiConfig): - # If default was single OpenaiConfig, update it with merged values - server_configs = OpenaiConfig(**openai_config_dict) - elif isinstance(default_server_configs, list): - # If default was list (presumably of OpenaiConfig), update the first or create one - # This assumes the primary server config is the one overridden via CLI/YAML - if default_server_configs and isinstance( - default_server_configs[0], OpenaiConfig - ): - # Update the first element, keep others as they were - server_configs = [ - OpenaiConfig(**openai_config_dict) - ] + default_server_configs[1:] - else: - # If list was empty or didn't contain OpenaiConfig, create a new list with the merged config - server_configs = [OpenaiConfig(**openai_config_dict)] - # If default_server_configs was ServerBaseline, server_configs remains ServerBaseline, - # effectively ignoring the openai_config_dict unless the user explicitly provides - # OpenaiConfig settings via CLI/YAML, which would be captured but not used here unless - # the environment class's config_init returned an OpenaiConfig or list. # --- Create and Run Environment --- # Create the environment instance using the final, instantiated config objects env = cls( config=env_config, - server_configs=server_configs, + server_configs=openai_configs, slurm=server_manager_config.slurm, testing=server_manager_config.testing, ) + rprint(env_config) + rprint(openai_configs) # Run the environment's main asynchronous manager function asyncio.run(env.env_manager()) @@ -1231,10 +1238,16 @@ class BaseEnv(ABC): "testing" ] + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + server_manager_config_dict = merge_dicts( ServerManagerConfig().model_dump(), # Base defaults PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults - yaml_config.get(SERVER_MANAGER_NAMESPACE, {}), # YAML config + server_manager_yaml_dict, server_manager_cli_passed_flags, # CLI args ) @@ -1242,17 +1255,29 @@ class BaseEnv(ABC): # Use the classes with adjusted defaults for instantiation env_config = env_config_cls_new_defaults(**env_config_dict) - openai_config = openai_config_cls_new_defaults(**openai_config_dict) server_manager_config = server_manager_config_cls_new_defaults( **server_manager_config_dict ) + # Determine the final server_configs, handling single, multiple servers, and overrides. + + openai_configs = resolve_openai_configs( + default_server_configs=default_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + + rprint(env_config) + rprint(openai_configs) + # --- Create and Run Environment --- # Create the environment instance env = cls( config=env_config, - # Process mode always uses a single OpenAI config, passed as a list - server_configs=[openai_config], + # Use the resolved configs (single or list) + server_configs=openai_configs, slurm=server_manager_config.slurm, testing=server_manager_config.testing, ) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 4a3562cb..f161869e 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -10,8 +10,11 @@ import openai from openai.types.chat.chat_completion import ChatCompletion from openai.types.completion import Completion from pydantic import BaseModel, Field +from pydantic_cli import FailedExecutionException from tenacity import retry, stop_after_attempt, wait_random_exponential +from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE + class OpenaiConfig(BaseModel): """ @@ -294,3 +297,79 @@ class OpenAIServer: self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data + + +def resolve_openai_configs( + default_server_configs, + openai_config_dict, + yaml_config, + cli_passed_flags, + logger, +): + """ + Helper to resolve the final server_configs, handling single, multiple servers, and overrides. + """ + from atroposlib.envs.server_handling.server_manager import ServerBaseline + + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None) + openai_cli_config = { + k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix) + } + + is_multi_server_yaml = ( + isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2 + ) + is_multi_server_default = ( + (not is_multi_server_yaml) + and isinstance(default_server_configs, list) + and len(default_server_configs) >= 2 + ) + + if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: + raise FailedExecutionException( + f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported " + f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' " + "or a default list with length >= 2)." + ) + + if is_multi_server_yaml: + logger.info( + f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." + ) + try: + server_configs = [OpenaiConfig(**cfg) for cfg in openai_yaml_config] + except Exception as e: + raise FailedExecutionException( + f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" + ) from e + elif isinstance(default_server_configs, ServerBaseline): + logger.info("Using ServerBaseline configuration.") + server_configs = default_server_configs + elif is_multi_server_default: + logger.info("Using default multi-server configuration (length >= 2).") + server_configs = default_server_configs + else: + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) + try: + final_openai_config = OpenaiConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + + if isinstance(default_server_configs, OpenaiConfig): + server_configs = final_openai_config + elif isinstance(default_server_configs, list): + server_configs = [final_openai_config] + else: + logger.warning( + f"Unexpected type for default_server_configs: {type(default_server_configs)}. " + f"Proceeding with single OpenAI server configuration based on merged settings." + ) + server_configs = [final_openai_config] + + return server_configs diff --git a/environments/configs/example.yaml b/environments/configs/example.yaml new file mode 100644 index 00000000..680458cf --- /dev/null +++ b/environments/configs/example.yaml @@ -0,0 +1,21 @@ +# Environment configuration +env: + group_size: 4 + max_batches_offpolicy: 3 + tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct" + use_wandb: true + rollout_server_url: "http://localhost:8000" + wandb_name: "example_env" + ensure_scores_are_not_same: true + data_path_to_save_groups: null + include_messages: true # if data_path_to_save_groups is set this will add the messages to the saved .jsonl + +# OpenAI server configurations +openai: + - model_name: "Qwen/Qwen2.5-1.5B-Instruct" + base_url: "http://localhost:9001/v1" + api_key: "x" + weight: 1.0 + +slurm: false +testing: false