add evaluate subcommand to cli

This commit is contained in:
hjc-puro 2025-07-07 17:39:33 -04:00
parent ecc5eebeca
commit 5519f190d2

View file

@ -1142,17 +1142,25 @@ class BaseEnv(ABC):
generate_html(self.config.data_path_to_save_groups)
async def _run_evaluate(self):
    """
    Prepare the environment, then run a single evaluation pass.

    ``self.setup()`` is awaited to completion before ``self.evaluate()`` is
    started. The results of both awaits are discarded and nothing is returned.
    """
    preparation = self.setup()
    await preparation
    evaluation = self.evaluate()
    await evaluation
@classmethod
def cli(cls):
"""
Command-line interface entry point for the environment.
This method handles the CLI commands for serve and process.
This method handles the CLI commands for serve, process, and evaluate.
"""
# Create subcommands dictionary
subcommands = {
"serve": cls.get_cli_serve_config_cls(),
"process": cls.get_cli_process_config_cls(),
"evaluate": cls.get_cli_evaluate_config_cls(),
}
# Custom exception handler for cleaner error output
@ -1603,3 +1611,251 @@ class BaseEnv(ABC):
asyncio.run(env.process_manager())
return CliProcessConfig
@classmethod
def get_cli_evaluate_config_cls(cls) -> type:
    """
    Build the CLI configuration class for the ``evaluate`` subcommand.

    The returned class inherits from the models produced by
    ``get_prefixed_pydantic_model`` for the environment config type and the
    OpenAI server config type (both derived from ``cls.config_init()``), plus
    ``ServerManagerConfig`` and ``Cmd``. Its ``run`` method merges the
    configuration sources (CLI > YAML > ``config_init`` defaults), builds the
    environment and executes ``env._run_evaluate()``.

    Returns:
        type: The CliEvaluateConfig class for evaluate commands.
    """
    # Get the default configurations from the specific environment class via config_init
    default_env_config_from_init, default_server_configs_from_init = (
        cls.config_init()
    )

    # Define namespace prefixes
    env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"

    # Create Pydantic model classes based on the types from config_init.
    # The defaults from config_init will be the primary source of defaults.
    env_config_cls_from_init = type(default_env_config_from_init)

    # Pick the class (and a default instance) used to define the OpenAI CLI args:
    #   * list whose first item is an APIServerConfig -> that item's type; the
    #     item itself supplies defaults only when it is the sole entry,
    #   * any other list (empty, or not starting with an APIServerConfig)
    #     -> plain APIServerConfig,
    #   * single APIServerConfig -> its type and the instance itself,
    #   * anything else (e.g. ServerBaseline) -> plain APIServerConfig so that
    #     YAML/CLI can still override.
    if isinstance(default_server_configs_from_init, list):
        if default_server_configs_from_init and isinstance(
            default_server_configs_from_init[0], APIServerConfig
        ):
            openai_config_cls_for_cli = type(default_server_configs_from_init[0])
            # Use the actual instance for default values later if it's a single config
            default_openai_config_instance_for_cli = (
                default_server_configs_from_init[0]
                if len(default_server_configs_from_init) == 1
                else openai_config_cls_for_cli()
            )
        else:
            openai_config_cls_for_cli = (
                APIServerConfig  # Default to APIServerConfig for CLI definition
            )
            default_openai_config_instance_for_cli = APIServerConfig()
    elif isinstance(default_server_configs_from_init, APIServerConfig):
        openai_config_cls_for_cli = type(default_server_configs_from_init)
        default_openai_config_instance_for_cli = default_server_configs_from_init
    else:  # ServerBaseline or other
        openai_config_cls_for_cli = APIServerConfig
        default_openai_config_instance_for_cli = APIServerConfig()

    class CliEvaluateConfig(
        get_prefixed_pydantic_model(env_config_cls_from_init, env_full_prefix),
        get_prefixed_pydantic_model(openai_config_cls_for_cli, openai_full_prefix),
        ServerManagerConfig,  # ServerManagerConfig defaults are fine as is.
        Cmd,
    ):
        """
        Configuration for the evaluate command.
        Supports overrides via YAML config file and CLI arguments.
        Order of precedence: CLI > YAML > `config_init` defaults.
        """

        config: str | None = Field(
            default=None,
            description="Path to .yaml config file. CLI args override this.",
        )

        def run(self) -> None:
            """The logic to execute for the 'evaluate' command."""
            # Local import: only needed for the "already inside an event loop"
            # fallback at the end of this method.
            import concurrent.futures

            # Set default wandb name if not provided and class has a name
            # NOTE(review): this only updates the attribute on this Cmd
            # instance; the env config below is merged from config_init
            # defaults, YAML and raw CLI flags -- confirm the default name is
            # picked up where intended.
            wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
            if (
                getattr(self, wandb_name_attr, None) is None
                and cls.name is not None
            ):
                setattr(self, wandb_name_attr, cls.name)

            # Load configuration from YAML file if specified
            if self.config is not None:
                with open(self.config, "r", encoding="utf-8") as f:
                    # safe_load returns None for an empty document; fall back
                    # to an empty mapping so the .get() lookups below work.
                    yaml_config = yaml.safe_load(f) or {}
                print(f"Loaded config from {self.config}")
            else:
                yaml_config = {}

            # Get CLI flags passed with double dashes
            cli_passed_flags = get_double_dash_flags()

            # --- Configuration Merging ---
            # Priority: CLI > YAML > `config_init` defaults

            # 1. Environment Configuration
            # Start with defaults from config_init
            env_config_dict_base = default_env_config_from_init.model_dump()
            # Apply specific overrides for evaluate mode that are generally useful
            env_config_dict_base["use_wandb"] = True

            env_config_dict = merge_dicts(
                env_config_dict_base,  # `config_init` defaults with evaluate adjustments
                yaml_config.get(ENV_NAMESPACE, {}),  # YAML config
                extract_namespace(cli_passed_flags, env_full_prefix),  # CLI args
            )

            # 2. OpenAI Configuration
            oai_cli_passed_args = extract_namespace(
                cli_passed_flags, openai_full_prefix
            )  # CLI args
            yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})

            # Base OpenAI config for merging: the instance chosen above for the
            # CLI definition. When config_init did not supply an
            # APIServerConfig (e.g. ServerBaseline), this is a default
            # APIServerConfig, so YAML/CLI can fully specify a single server.
            openai_config_dict_base = (
                default_openai_config_instance_for_cli.model_dump()
            )

            if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                # If YAML specifies a single server config for OpenAI namespace
                yaml_oai_single_server_config = yaml_oai_config[0]
            elif isinstance(yaml_oai_config, dict):
                yaml_oai_single_server_config = yaml_oai_config
            else:
                # Multi-entry lists (or other shapes) are not merged here; the
                # full YAML is still passed to resolve_openai_configs below.
                yaml_oai_single_server_config = {}

            openai_config_dict = merge_dicts(
                openai_config_dict_base,  # Default from config_init (or default APIServerConfig)
                yaml_oai_single_server_config,  # YAML config for a single server
                oai_cli_passed_args,  # CLI args
            )

            # 3. Server Manager Configuration
            server_manager_keys = ("slurm", "testing")
            server_manager_cli_passed_flags = {
                key: cli_passed_flags[key]
                for key in server_manager_keys
                if key in cli_passed_flags
            }
            server_manager_yaml_dict = {
                key: yaml_config[key]
                for key in server_manager_keys
                if key in yaml_config
            }

            # Start with ServerManagerConfig defaults, then apply YAML, then CLI
            # For evaluate mode, slurm and testing are typically False unless specified.
            server_manager_config_dict_base = ServerManagerConfig(
                slurm=False, testing=False
            ).model_dump()

            server_manager_config_dict = merge_dicts(
                server_manager_config_dict_base,
                server_manager_yaml_dict,
                server_manager_cli_passed_flags,
            )

            # --- Instantiate Final Config Objects ---
            # Use the original class types from config_init (or APIServerConfig for OpenAI CLI)
            env_config = env_config_cls_from_init(**env_config_dict)
            server_manager_config = ServerManagerConfig(
                **server_manager_config_dict
            )

            # Determine the final server_configs. resolve_openai_configs gets
            # the original config_init structure, the merged single-server
            # dict, and the full YAML / CLI inputs.
            final_openai_configs = resolve_openai_configs(
                default_server_configs=default_server_configs_from_init,  # Pass the original structure
                openai_config_dict=openai_config_dict,  # This is the merged single server config for CLI/YAML
                yaml_config=yaml_config,  # Pass full YAML for resolve_openai_configs logic
                cli_passed_flags=cli_passed_flags,  # Pass full CLI for resolve_openai_configs
                logger=logger,
            )

            # Warn (once) if any resolved APIServerConfig points at a local address.
            local_markers = ("localhost", "0.0.0.0", "127.0.0.1")
            configs_to_check = (
                final_openai_configs
                if isinstance(final_openai_configs, list)
                else [final_openai_configs]
            )
            uses_local_server = any(
                isinstance(cfg, APIServerConfig)
                and cfg.base_url
                and any(marker in cfg.base_url for marker in local_markers)
                for cfg in configs_to_check
            )
            if uses_local_server:
                warnings.warn(
                    "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
                    "Ensure you have a server running at this address or results may not be generated.",
                    UserWarning,
                )

            rprint(env_config)
            rprint(final_openai_configs)

            # --- Create and Run Environment ---
            # Create the environment instance
            env = cls(
                config=env_config,
                server_configs=final_openai_configs,
                slurm=server_manager_config.slurm,
                testing=server_manager_config.testing,
            )

            print("Running evaluation...")
            # Only the loop probe lives in the try block, so a RuntimeError
            # raised by the evaluation itself is never mistaken for "no loop".
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop is running in this thread: the normal CLI path.
                asyncio.run(env._run_evaluate())
            else:
                # A loop is already running in this thread, so neither
                # asyncio.run() nor loop.run_until_complete() may be called
                # here. Run the evaluation on a fresh loop in a worker thread
                # and block until it finishes; .result() re-raises any error.
                # NOTE(review): assumes `env` holds no objects bound to the
                # caller's loop at construction time -- confirm.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    pool.submit(
                        lambda: asyncio.run(env._run_evaluate())
                    ).result()

            print("Evaluation completed.")

    return CliEvaluateConfig