mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add evaluate subcommand to cli
This commit is contained in:
parent
ecc5eebeca
commit
5519f190d2
1 changed files with 257 additions and 1 deletions
|
|
@ -1142,17 +1142,25 @@ class BaseEnv(ABC):
|
|||
|
||||
generate_html(self.config.data_path_to_save_groups)
|
||||
|
||||
async def _run_evaluate(self):
|
||||
"""
|
||||
Internal method to run evaluation with proper setup.
|
||||
"""
|
||||
await self.setup()
|
||||
await self.evaluate()
|
||||
|
||||
@classmethod
|
||||
def cli(cls):
|
||||
"""
|
||||
Command-line interface entry point for the environment.
|
||||
This method handles the CLI commands for serve and process.
|
||||
This method handles the CLI commands for serve, process, and evaluate.
|
||||
"""
|
||||
|
||||
# Create subcommands dictionary
|
||||
subcommands = {
|
||||
"serve": cls.get_cli_serve_config_cls(),
|
||||
"process": cls.get_cli_process_config_cls(),
|
||||
"evaluate": cls.get_cli_evaluate_config_cls(),
|
||||
}
|
||||
|
||||
# Custom exception handler for cleaner error output
|
||||
|
|
@ -1603,3 +1611,251 @@ class BaseEnv(ABC):
|
|||
asyncio.run(env.process_manager())
|
||||
|
||||
return CliProcessConfig
|
||||
|
||||
@classmethod
|
||||
def get_cli_evaluate_config_cls(cls) -> type:
|
||||
"""
|
||||
Returns the CLI configuration class for evaluate commands.
|
||||
|
||||
Returns:
|
||||
type: The CliEvaluateConfig class for evaluate commands.
|
||||
"""
|
||||
# Get the default configurations from the specific environment class via config_init
|
||||
default_env_config_from_init, default_server_configs_from_init = (
|
||||
cls.config_init()
|
||||
)
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Create Pydantic model classes based on the types from config_init.
|
||||
# The defaults from config_init will be the primary source of defaults.
|
||||
env_config_cls_from_init = type(default_env_config_from_init)
|
||||
|
||||
# Handle server_configs_from_init appropriately for creating a default CLI model
|
||||
# If it's a list (multiple servers), we'll take the first one as a template for CLI args,
|
||||
# or use APIServerConfig if the list is empty or contains ServerBaseline.
|
||||
# If it's a single APIServerConfig, we use its type.
|
||||
# If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
|
||||
if isinstance(default_server_configs_from_init, list):
|
||||
if default_server_configs_from_init and isinstance(
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init[0])
|
||||
# Use the actual instance for default values later if it's a single config
|
||||
default_openai_config_instance_for_cli = (
|
||||
default_server_configs_from_init[0]
|
||||
if len(default_server_configs_from_init) == 1
|
||||
else openai_config_cls_for_cli()
|
||||
)
|
||||
else:
|
||||
openai_config_cls_for_cli = (
|
||||
APIServerConfig # Default to APIServerConfig for CLI definition
|
||||
)
|
||||
default_openai_config_instance_for_cli = APIServerConfig()
|
||||
elif isinstance(default_server_configs_from_init, APIServerConfig):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init)
|
||||
default_openai_config_instance_for_cli = default_server_configs_from_init
|
||||
else: # ServerBaseline or other
|
||||
openai_config_cls_for_cli = APIServerConfig
|
||||
default_openai_config_instance_for_cli = APIServerConfig()
|
||||
|
||||
class CliEvaluateConfig(
|
||||
get_prefixed_pydantic_model(env_config_cls_from_init, env_full_prefix),
|
||||
get_prefixed_pydantic_model(openai_config_cls_for_cli, openai_full_prefix),
|
||||
ServerManagerConfig, # ServerManagerConfig defaults are fine as is.
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the evaluate command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > `config_init` defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'evaluate' command."""
|
||||
# Set default wandb name if not provided and class has a name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
# Load configuration from YAML file if specified
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
# Get CLI flags passed with double dashes
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > `config_init` defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
# Start with defaults from config_init
|
||||
env_config_dict_base = default_env_config_from_init.model_dump()
|
||||
# Apply specific overrides for evaluate mode that are generally useful
|
||||
env_config_dict_base["use_wandb"] = True
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
env_config_dict_base, # `config_init` defaults with evaluate adjustments
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 2. OpenAI Configuration
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
# Determine the base OpenAI config from config_init for merging
|
||||
# This uses the instance we determined earlier for CLI definition defaults
|
||||
openai_config_dict_base = (
|
||||
default_openai_config_instance_for_cli.model_dump()
|
||||
)
|
||||
|
||||
if isinstance(default_server_configs_from_init, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
# If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
|
||||
# it implies an override intent for a single server.
|
||||
# We use the default_openai_config_instance_for_cli (which would be a default APIServerConfig)
|
||||
# as the base for merging, allowing it to be fully specified by YAML/CLI.
|
||||
pass # Base is already set correctly for this case
|
||||
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
# If YAML specifies a single server config for OpenAI namespace
|
||||
yaml_oai_single_server_config = yaml_oai_config[0]
|
||||
elif isinstance(yaml_oai_config, dict):
|
||||
yaml_oai_single_server_config = yaml_oai_config
|
||||
else:
|
||||
yaml_oai_single_server_config = {}
|
||||
|
||||
openai_config_dict = merge_dicts(
|
||||
openai_config_dict_base, # Default from config_init (or default APIServerConfig)
|
||||
yaml_oai_single_server_config, # YAML config for a single server
|
||||
oai_cli_passed_args, # CLI args
|
||||
)
|
||||
|
||||
# 3. Server Manager Configuration
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
# Start with ServerManagerConfig defaults, then apply YAML, then CLI
|
||||
# For evaluate mode, slurm and testing are typically False unless specified.
|
||||
server_manager_config_dict_base = ServerManagerConfig(
|
||||
slurm=False, testing=False
|
||||
).model_dump()
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
server_manager_config_dict_base,
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Use the original class types from config_init (or APIServerConfig for OpenAI CLI)
|
||||
|
||||
env_config = env_config_cls_from_init(**env_config_dict)
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs.
|
||||
# For 'evaluate', we typically expect a single server configuration for the OAI part.
|
||||
# The resolve_openai_configs will handle complex cases, but for 'evaluate',
|
||||
# the openai_config_dict we built should represent the single intended server.
|
||||
|
||||
# If default_server_configs_from_init was ServerBaseline, and we have openai_config_dict,
|
||||
# it means we are overriding to use a specific APIServerConfig.
|
||||
# If default_server_configs_from_init was a list or single APIServerConfig,
|
||||
# resolve_openai_configs will merge appropriately.
|
||||
|
||||
final_openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs_from_init, # Pass the original structure
|
||||
openai_config_dict=openai_config_dict, # This is the merged single server config for CLI/YAML
|
||||
yaml_config=yaml_config, # Pass full YAML for resolve_openai_configs logic
|
||||
cli_passed_flags=cli_passed_flags, # Pass full CLI for resolve_openai_configs
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Add warning for localhost or 0.0.0.0
|
||||
if isinstance(final_openai_configs, list):
|
||||
for cfg in final_openai_configs:
|
||||
if (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
"localhost" in cfg.base_url
|
||||
or "0.0.0.0" in cfg.base_url
|
||||
or "127.0.0.1" in cfg.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
|
||||
"Ensure you have a server running at this address or results may not be generated.",
|
||||
UserWarning,
|
||||
)
|
||||
break # Warn once
|
||||
elif (
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
|
||||
"Ensure you have a server running at this address or results may not be generated.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
rprint(env_config)
|
||||
rprint(final_openai_configs)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=final_openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
||||
print("Running evaluation...")
|
||||
# Handle the case where we might already be in an event loop
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(env._run_evaluate())
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
asyncio.run(env._run_evaluate())
|
||||
|
||||
print("Evaluation completed.")
|
||||
|
||||
return CliEvaluateConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue