# Methods added to ``BaseEnv`` in atroposlib/envs/base.py to support an
# ``evaluate`` CLI subcommand.  They are shown dedented here; inside the class
# they sit one indentation level deeper.  ``BaseEnv.cli()`` exposes the command
# by adding ``"evaluate": cls.get_cli_evaluate_config_cls()`` to its
# ``subcommands`` dictionary, alongside ``"serve"`` and ``"process"``, and its
# docstring reads "This method handles the CLI commands for serve, process,
# and evaluate."


async def _run_evaluate(self):
    """
    Internal method to run evaluation with proper setup.

    Awaits ``self.setup()`` first and ``self.evaluate()`` second, so that
    evaluation never starts on an environment that has not been set up.
    """
    await self.setup()
    await self.evaluate()


@classmethod
def get_cli_evaluate_config_cls(cls) -> type:
    """
    Returns the CLI configuration class for evaluate commands.

    The class is assembled from the objects returned by ``cls.config_init()``:
    the type of the environment config and the type of the (first) server
    config become prefixed CLI models, and the instances themselves supply
    the default values that YAML and CLI input are layered on top of.

    Returns:
        type: The CliEvaluateConfig class for evaluate commands.
    """
    # Get the default configurations from the specific environment class via config_init
    default_env_config_from_init, default_server_configs_from_init = (
        cls.config_init()
    )

    # Define namespace prefixes
    env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"

    # Create Pydantic model classes based on the types from config_init.
    # The defaults from config_init will be the primary source of defaults.
    env_config_cls_from_init = type(default_env_config_from_init)

    # Pick the class used to expose OpenAI options on the CLI and the instance
    # whose values act as their defaults:
    #   * list whose first item is an APIServerConfig -> that item's type; the
    #     item itself is the default only when it is the sole entry, otherwise
    #     a freshly constructed instance of that type is used.
    #   * empty list, or list starting with something else -> APIServerConfig.
    #   * single APIServerConfig -> its type and the instance itself.
    #   * anything else (e.g. ServerBaseline) -> APIServerConfig, so that
    #     YAML/CLI can still fully specify a server.
    if isinstance(default_server_configs_from_init, list):
        if default_server_configs_from_init and isinstance(
            default_server_configs_from_init[0], APIServerConfig
        ):
            openai_config_cls_for_cli = type(default_server_configs_from_init[0])
            default_openai_config_instance_for_cli = (
                default_server_configs_from_init[0]
                if len(default_server_configs_from_init) == 1
                else openai_config_cls_for_cli()
            )
        else:
            openai_config_cls_for_cli = APIServerConfig
            default_openai_config_instance_for_cli = APIServerConfig()
    elif isinstance(default_server_configs_from_init, APIServerConfig):
        openai_config_cls_for_cli = type(default_server_configs_from_init)
        default_openai_config_instance_for_cli = default_server_configs_from_init
    else:  # ServerBaseline or other
        openai_config_cls_for_cli = APIServerConfig
        default_openai_config_instance_for_cli = APIServerConfig()

    # Substrings that mark a base URL as pointing at the local machine.
    local_host_markers = ("localhost", "0.0.0.0", "127.0.0.1")

    def _uses_local_base_url(server_config) -> bool:
        """Return True if *server_config* is an APIServerConfig with a local base_url."""
        if not isinstance(server_config, APIServerConfig):
            return False
        base_url = server_config.base_url
        return bool(base_url) and any(
            marker in base_url for marker in local_host_markers
        )

    class CliEvaluateConfig(
        get_prefixed_pydantic_model(env_config_cls_from_init, env_full_prefix),
        get_prefixed_pydantic_model(openai_config_cls_for_cli, openai_full_prefix),
        ServerManagerConfig,  # ServerManagerConfig defaults are fine as is.
        Cmd,
    ):
        """
        Configuration for the evaluate command.
        Supports overrides via YAML config file and CLI arguments.
        Order of precedence: CLI > YAML > `config_init` defaults.
        """

        config: str | None = Field(
            default=None,
            description="Path to .yaml config file. CLI args override this.",
        )

        def run(self) -> None:
            """
            The logic to execute for the 'evaluate' command.

            Merges `config_init` defaults, the optional YAML file and the
            CLI flags into the final config objects, builds the environment
            and runs ``env._run_evaluate()`` to completion.

            Raises:
                ValueError: If the YAML file's top-level value is not a mapping.
            """
            # Set default wandb name if not provided and class has a name
            wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
            if (
                getattr(self, wandb_name_attr, None) is None
                and cls.name is not None
            ):
                setattr(self, wandb_name_attr, cls.name)

            # Load configuration from YAML file if specified
            if self.config is not None:
                with open(self.config, "r", encoding="utf-8") as f:
                    yaml_config = yaml.safe_load(f)
                if yaml_config is None:
                    # safe_load returns None for an empty document.
                    yaml_config = {}
                elif not isinstance(yaml_config, dict):
                    raise ValueError(
                        f"Config file {self.config} must contain a YAML mapping "
                        f"at the top level, got {type(yaml_config).__name__}."
                    )
                print(f"Loaded config from {self.config}")
            else:
                yaml_config = {}

            # Get CLI flags passed with double dashes
            cli_passed_flags = get_double_dash_flags()

            # --- Configuration Merging ---
            # Priority: CLI > YAML > `config_init` defaults

            # 1. Environment Configuration
            # Start with defaults from config_init
            env_config_dict_base = default_env_config_from_init.model_dump()
            # Evaluate mode turns wandb on by default; YAML/CLI are merged
            # afterwards and can still override it.
            env_config_dict_base["use_wandb"] = True

            env_config_dict = merge_dicts(
                env_config_dict_base,  # `config_init` defaults with evaluate adjustments
                yaml_config.get(ENV_NAMESPACE) or {},  # YAML config (empty section -> {})
                extract_namespace(cli_passed_flags, env_full_prefix),  # CLI args
            )

            # The setattr above only changes the parsed CLI model, which is not
            # an input to env_config_dict. Apply the same default to the merged
            # dict so it actually reaches the environment config.
            if (
                "wandb_name" in env_config_dict
                and env_config_dict["wandb_name"] is None
                and cls.name is not None
            ):
                env_config_dict["wandb_name"] = cls.name

            # 2. OpenAI Configuration
            oai_cli_passed_args = extract_namespace(
                cli_passed_flags, openai_full_prefix
            )  # CLI args
            yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})

            # Base for merging: the instance chosen above for CLI defaults.
            # When config_init returned ServerBaseline this is a default
            # APIServerConfig, so YAML/CLI can fully specify a single server.
            openai_config_dict_base = (
                default_openai_config_instance_for_cli.model_dump()
            )

            if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                # If YAML specifies a single server config for OpenAI namespace
                yaml_oai_single_server_config = yaml_oai_config[0]
            elif isinstance(yaml_oai_config, dict):
                yaml_oai_single_server_config = yaml_oai_config
            else:
                # Multi-entry lists (and anything else) are not merged here;
                # the full YAML is still handed to resolve_openai_configs below.
                yaml_oai_single_server_config = {}

            openai_config_dict = merge_dicts(
                openai_config_dict_base,  # Default from config_init (or default APIServerConfig)
                yaml_oai_single_server_config,  # YAML config for a single server
                oai_cli_passed_args,  # CLI args
            )

            # 3. Server Manager Configuration
            # Only these two top-level keys are forwarded to ServerManagerConfig.
            server_manager_keys = ("slurm", "testing")
            server_manager_cli_passed_flags = {
                key: cli_passed_flags[key]
                for key in server_manager_keys
                if key in cli_passed_flags
            }
            server_manager_yaml_dict = {
                key: yaml_config[key]
                for key in server_manager_keys
                if key in yaml_config
            }

            # Start with ServerManagerConfig defaults, then apply YAML, then CLI
            # For evaluate mode, slurm and testing are typically False unless specified.
            server_manager_config_dict_base = ServerManagerConfig(
                slurm=False, testing=False
            ).model_dump()

            server_manager_config_dict = merge_dicts(
                server_manager_config_dict_base,
                server_manager_yaml_dict,
                server_manager_cli_passed_flags,
            )

            # --- Instantiate Final Config Objects ---
            # Use the original class types from config_init (or APIServerConfig for OpenAI CLI)
            env_config = env_config_cls_from_init(**env_config_dict)
            server_manager_config = ServerManagerConfig(
                **server_manager_config_dict
            )

            # resolve_openai_configs receives the original config_init structure,
            # the merged single-server dict, and the full YAML/CLI input, and
            # decides the final server configuration from those.
            final_openai_configs = resolve_openai_configs(
                default_server_configs=default_server_configs_from_init,
                openai_config_dict=openai_config_dict,
                yaml_config=yaml_config,
                cli_passed_flags=cli_passed_flags,
                logger=logger,
            )

            # Warn (once) when any resolved server points at a local address.
            if isinstance(final_openai_configs, list):
                servers_to_check = final_openai_configs
            else:
                servers_to_check = [final_openai_configs]
            if any(_uses_local_base_url(server) for server in servers_to_check):
                warnings.warn(
                    "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
                    "Ensure you have a server running at this address or results may not be generated.",
                    UserWarning,
                )

            rprint(env_config)
            rprint(final_openai_configs)

            # --- Create and Run Environment ---
            env = cls(
                config=env_config,
                server_configs=final_openai_configs,
                slurm=server_manager_config.slurm,
                testing=server_manager_config.testing,
            )

            print("Running evaluation...")
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop in this thread: the normal CLI case.
                asyncio.run(env._run_evaluate())
            else:
                # A loop is already running in this thread. Both
                # loop.run_until_complete() and asyncio.run() raise
                # RuntimeError in that situation, and this method is
                # synchronous so it cannot await. Run the evaluation on a
                # fresh loop in a worker thread and block until it finishes;
                # .result() re-raises any exception from the evaluation.
                import concurrent.futures

                def _evaluate_on_fresh_loop() -> None:
                    asyncio.run(env._run_evaluate())

                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    pool.submit(_evaluate_on_fresh_loop).result()

            print("Evaluation completed.")

    return CliEvaluateConfig