diff --git a/README.md b/README.md
index aaaf3a52..43142b50 100644
--- a/README.md
+++ b/README.md
@@ -167,9 +167,15 @@ pre-commit install
 
 You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running VLLM or SGLang inference server as well as any other configuration changes you'd like to make, such as the group size, then:
 ```bash
-# Start the API server and run the GSM8K environment
-run-api & python environments/gsm8k_server.py serve \
-  --slurm false
+# Start the API server
+run-api
+```
+In a separate terminal, start the GSM8K environment microservice:
+```bash
+python environments/gsm8k_server.py serve --openai.model_name Qwen/Qwen2.5-1.5B-Instruct --slurm false
+# Alternatively, pass a YAML config file:
+# python environments/gsm8k_server.py serve --config environments/configs/example.yaml
+# python environments/gsm8k_server.py serve --config environments/configs/example.yaml --env.group_size 8  # CLI args override the corresponding config settings
 ```
 
 3. **Query the the API (Optional)**
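The README commands above rely on the precedence order implemented later in this diff: CLI flags override the `--config` YAML, which overrides the defaults returned by `config_init`. Below is a minimal sketch of that last-wins layering; the `merge_dicts` shown here is only a stand-in for atroposlib's helper of the same name, whose exact handling of unset values is assumed rather than taken from this diff.

```python
# Minimal stand-in for atroposlib's merge_dicts: later sources win (assumption).
def merge_dicts(*dicts: dict) -> dict:
    merged: dict = {}
    for d in dicts:
        merged.update(d)  # later dicts override earlier ones
    return merged


class_defaults = {"group_size": 16, "use_wandb": False}  # from config_init
yaml_overrides = {"group_size": 4, "use_wandb": True}    # from --config example.yaml
cli_overrides = {"group_size": 8}                        # from --env.group_size 8

print(merge_dicts(class_defaults, yaml_overrides, cli_overrides))
# {'group_size': 8, 'use_wandb': True}
```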
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index ba8b8ee4..c8b9d04b 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -20,15 +20,12 @@ import wandb
 import yaml
 from pydantic import BaseModel, Field
 from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
+from rich import print as rprint
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from transformers import AutoTokenizer
 
-from atroposlib.envs.constants import (
-    ENV_NAMESPACE,
-    NAMESPACE_SEP,
-    OPENAI_NAMESPACE,
-    SERVER_MANAGER_NAMESPACE,
-)
+from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
+from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
 from atroposlib.frontend.jsonl2html import generate_html
 from atroposlib.type_definitions import UUID
 from atroposlib.utils.cli import (
@@ -957,36 +954,150 @@ class BaseEnv(ABC):
             type: The CliServeConfig class for serving commands.
         """
-        env_config, server_configs = cls.config_init()
+        # Get the default configurations defined by the specific environment class
+        default_env_config, default_server_configs = cls.config_init()
+
+        # Define namespace prefixes for CLI arguments and YAML keys
         env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
         openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
 
+        # Define the CLI configuration class dynamically
         class CliServeConfig(
-            get_prefixed_pydantic_model(type(env_config), env_full_prefix),
-            get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
-            ServerManagerConfig,
+            get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
+            get_prefixed_pydantic_model(
+                OpenaiConfig, openai_full_prefix
+            ),  # Use OpenaiConfig for CLI args
+            ServerManagerConfig,  # ServerManager args are not namespaced by default
             Cmd,
         ):
             """
             Configuration for the serve command.
-            This combines BaseEnvConfig and OpenaiConfig into a single command.
+            Supports overrides via YAML config file and CLI arguments.
+            Order of precedence: CLI > YAML > Class Defaults.
             """
 
+            config: str | None = Field(
+                default=None,
+                description="Path to .yaml config file. CLI args override this.",
+            )
+
             def run(self) -> None:
                 """The logic to execute for the 'serve' command."""
-                # Convert this config into the formats needed by BaseEnv
+                # Set default wandb name if not provided and class has a name
+                # Note: This modifies the 'self' instance based on CLI args before full parsing.
                 wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
-                if getattr(self, wandb_name_attr) is None and cls.name is not None:
+                if (
+                    getattr(self, wandb_name_attr, None) is None
+                    and cls.name is not None
+                ):
                     setattr(self, wandb_name_attr, cls.name)
-                model_dumped = self.model_dump(exclude_unset=True)
-                server_manager_config = ServerManagerConfig(**model_dumped)
-                # Create the environment instance
+
+                # Load configuration from YAML file if specified
+                if self.config is not None:
+                    with open(self.config, "r") as f:
+                        yaml_config = yaml.safe_load(f)
+                    print(f"Loaded config from {self.config}")
+                else:
+                    yaml_config = {}
+
+                # Get CLI flags passed with double dashes (e.g., --env.foo bar)
+                cli_passed_flags = get_double_dash_flags()
+
+                # --- Configuration Merging ---
+                # Priority: CLI > YAML > Class Defaults
+
+                # 1. Environment Configuration
+                env_config_dict = merge_dicts(
+                    default_env_config.model_dump(),  # Class Defaults
+                    yaml_config.get(ENV_NAMESPACE, {}),  # YAML config
+                    extract_namespace(cli_passed_flags, env_full_prefix),  # CLI args
+                )
+
+                # 2. OpenAI Configuration (used for potential overrides)
+                oai_cli_passed_args = extract_namespace(
+                    cli_passed_flags, openai_full_prefix
+                )  # CLI args
+                yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
+                if isinstance(default_server_configs, ServerBaseline) and (
+                    oai_cli_passed_args or yaml_oai_config
+                ):
+                    raise ValueError(
+                        "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments or YAML overrides. Please edit `config_init` directly or use OpenaiConfig."  # noqa: E501
+                    )
+                if (
+                    isinstance(default_server_configs, list)
+                    and len(default_server_configs) == 1
+                ):
+                    # can't use the same var name because it shadows the class variable and we get an error
+                    default_openai_config_ = default_server_configs[0]
+                else:
+                    default_openai_config_ = default_server_configs
+                if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
+                    yaml_oai_config = yaml_oai_config[0]
+                if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
+                    yaml_oai_config, dict
+                ):
+                    openai_config_dict = merge_dicts(
+                        default_openai_config_.model_dump(),  # Default OpenaiConfig (or from class init)
+                        yaml_oai_config,
+                        oai_cli_passed_args,
+                    )
+                else:
+                    openai_config_dict = {}
+
+                # 3. Server Manager Configuration (slurm, testing - not namespaced)
+                # Extract only relevant CLI flags for ServerManager
+                server_manager_cli_passed_flags = {}
+                if "slurm" in cli_passed_flags:
+                    server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
+                if "testing" in cli_passed_flags:
+                    server_manager_cli_passed_flags["testing"] = cli_passed_flags[
+                        "testing"
+                    ]
+
+                server_manager_yaml_dict = {}
+                if "slurm" in yaml_config:
+                    server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
+                if "testing" in yaml_config:
+                    server_manager_yaml_dict["testing"] = yaml_config["testing"]
+
+                server_manager_config_dict = merge_dicts(
+                    ServerManagerConfig().model_dump(),  # Base defaults for ServerManager
+                    server_manager_yaml_dict,  # YAML config
+                    server_manager_cli_passed_flags,  # CLI args
+                )
+
+                # --- Instantiate Final Config Objects ---
+                # Create instances from the merged dictionaries using the original default types where appropriate
+
+                # Instantiate the final environment config using its original type
+                env_config = type(default_env_config)(**env_config_dict)
+
+                # Instantiate the final server manager config
+                server_manager_config = ServerManagerConfig(
+                    **server_manager_config_dict
+                )
+
+                # Determine the final server_configs, handling single, multiple servers, and overrides.
+
+                openai_configs = resolve_openai_configs(
+                    default_server_configs=default_server_configs,
+                    openai_config_dict=openai_config_dict,
+                    yaml_config=yaml_config,
+                    cli_passed_flags=cli_passed_flags,
+                    logger=logger,
+                )
+
+                # --- Create and Run Environment ---
+                # Create the environment instance using the final, instantiated config objects
                 env = cls(
                     config=env_config,
-                    server_configs=server_configs,
+                    server_configs=openai_configs,
                     slurm=server_manager_config.slurm,
                     testing=server_manager_config.testing,
                 )
+                rprint(env_config)
+                rprint(openai_configs)
 
                 # Run the environment
                 asyncio.run(env.env_manager())
@@ -1002,11 +1113,14 @@ class BaseEnv(ABC):
             type: The CliProcessConfig class for processing commands.
         """
 
+        # Define specific default configurations for the 'process' mode
         PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
             group_size=8,
             total_steps=2,
             ensure_scores_are_not_same=False,
             include_messages=True,
+            data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
+            use_wandb=True,
         )
         PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
             model_name="gpt-4.1-nano",
@@ -1018,19 +1132,24 @@ class BaseEnv(ABC):
             testing=False,
         )
 
-        default_env_config, default_openai_config = cls.config_init()
-
-        if isinstance(default_openai_config, list):
-            default_openai_config = default_openai_config[0]
+        # Get the base default configurations from the specific environment class
+        (
+            default_env_config,
+            default_server_configs,
+        ) = cls.config_init()
 
+        # Define namespace prefixes
         env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
         openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
 
+        # Create Pydantic model classes with the 'process' mode defaults applied.
+        # These adjusted classes will be used for final instantiation.
         env_config_cls_new_defaults = adjust_model_defaults(
             type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
         )
         openai_config_cls_new_defaults = adjust_model_defaults(
-            OpenaiConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
+            OpenaiConfig,
+            PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
         )
         server_manager_config_cls_new_defaults = adjust_model_defaults(
             ServerManagerConfig,
@@ -1047,6 +1166,8 @@ class BaseEnv(ABC):
         ):
             """
             Configuration for the process command.
+            Supports overrides via YAML config file and CLI arguments.
+            Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults.
             """
 
             config: str | None = Field(
@@ -1056,42 +1177,72 @@ class BaseEnv(ABC):
 
             def run(self) -> None:
                 """The logic to execute for the 'process' command."""
-                # Setup environment configuration
+                # Set default wandb name if not provided and class has a name
                 wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
-                if getattr(self, wandb_name_attr) is None and cls.name is not None:
+                if (
+                    getattr(self, wandb_name_attr, None) is None
+                    and cls.name is not None
+                ):
                     setattr(self, wandb_name_attr, cls.name)
 
+                # Load configuration from YAML file if specified
                 if self.config is not None:
                     with open(self.config, "r") as f:
-                        config = yaml.safe_load(f)
+                        yaml_config = yaml.safe_load(f)
                     print(f"Loaded config from {self.config}")
                 else:
-                    config = {}
+                    yaml_config = {}
 
+                # Get CLI flags passed with double dashes
                 cli_passed_flags = get_double_dash_flags()
 
-                # cli args overrides config file which overrides class defaults which overrides process mode defaults
-                env_config = env_config_cls_new_defaults(
-                    **merge_dicts(
-                        default_env_config.model_dump(),
-                        PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
-                        config.get(ENV_NAMESPACE, {}),
-                        extract_namespace(
-                            cli_passed_flags, env_full_prefix
-                        ),  # only extract namespace for cli-passed args
-                    )
-                )
-                openai_config = openai_config_cls_new_defaults(
-                    **merge_dicts(
-                        default_openai_config.model_dump(),
-                        PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
-                        config.get(OPENAI_NAMESPACE, {}),
-                        extract_namespace(
-                            cli_passed_flags, openai_full_prefix
-                        ),  # only extract namespace for cli-passed args
-                    )
+                # --- Configuration Merging ---
+                # Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults
+
+                # 1. Environment Configuration
+                env_config_dict = merge_dicts(
+                    default_env_config.model_dump(),  # Class Defaults
+                    PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),  # Process Mode Defaults
+                    yaml_config.get(ENV_NAMESPACE, {}),  # YAML config
+                    extract_namespace(cli_passed_flags, env_full_prefix),  # CLI args
                 )
 
+                # 2. OpenAI Configuration
+                oai_cli_passed_args = extract_namespace(
+                    cli_passed_flags, openai_full_prefix
+                )  # CLI args
+                yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
+                if isinstance(default_server_configs, ServerBaseline) and (
+                    oai_cli_passed_args or yaml_oai_config
+                ):
+                    raise ValueError(
+                        "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments or YAML overrides. Please edit `config_init` directly or use OpenaiConfig."  # noqa: E501
+                    )
+
+                if (
+                    isinstance(default_server_configs, list)
+                    and len(default_server_configs) == 1
+                ):
+                    # can't use the same var name because it shadows the class variable and we get an error
+                    default_openai_config_ = default_server_configs[0]
+                else:
+                    default_openai_config_ = default_server_configs
+                if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
+                    yaml_oai_config = yaml_oai_config[0]
+                if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
+                    yaml_oai_config, dict
+                ):
+                    openai_config_dict = merge_dicts(
+                        default_openai_config_.model_dump(),  # Default OpenaiConfig (or from class init)
+                        PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),  # Process Mode Defaults
+                        yaml_oai_config,
+                        oai_cli_passed_args,
+                    )
+                else:
+                    openai_config_dict = {}
+
+                # 3. Server Manager Configuration
+                # Extract only relevant CLI flags
                 server_manager_cli_passed_flags = {}
                 if "slurm" in cli_passed_flags:
                     server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
@@ -1100,36 +1251,68 @@ class BaseEnv(ABC):
                         "testing"
                     ]
 
-                server_manager_config = server_manager_config_cls_new_defaults(
-                    **merge_dicts(
-                        ServerManagerConfig().model_dump(),
-                        PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
-                        config.get(SERVER_MANAGER_NAMESPACE, {}),
-                        server_manager_cli_passed_flags,
-                    )
+                server_manager_yaml_dict = {}
+                if "slurm" in yaml_config:
+                    server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
+                if "testing" in yaml_config:
+                    server_manager_yaml_dict["testing"] = yaml_config["testing"]
+
+                server_manager_config_dict = merge_dicts(
+                    ServerManagerConfig().model_dump(),  # Base defaults
+                    PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),  # Process Mode Defaults
+                    server_manager_yaml_dict,
+                    server_manager_cli_passed_flags,  # CLI args
                 )
 
+                # --- Instantiate Final Config Objects ---
+                # Use the classes with adjusted defaults for instantiation
+
+                env_config = env_config_cls_new_defaults(**env_config_dict)
+                server_manager_config = server_manager_config_cls_new_defaults(
+                    **server_manager_config_dict
+                )
+
+                # Determine the final server_configs, handling single, multiple servers, and overrides.
+
+                openai_configs = resolve_openai_configs(
+                    default_server_configs=default_server_configs,
+                    openai_config_dict=openai_config_dict,
+                    yaml_config=yaml_config,
+                    cli_passed_flags=cli_passed_flags,
+                    logger=logger,
+                )
+
+                rprint(env_config)
+                rprint(openai_configs)
+
+                # --- Create and Run Environment ---
                 # Create the environment instance
                 env = cls(
                     config=env_config,
-                    server_configs=[openai_config],
+                    server_configs=openai_configs,
                     slurm=server_manager_config.slurm,
                     testing=server_manager_config.testing,
                 )
 
-                # Set the process mode parameters
+                # Set specific parameters for process mode on the environment instance
                 env.process_mode = True
                 env.n_groups_to_process = env_config.total_steps
                 env.group_size_to_process = env_config.group_size
 
+                # Validate that an output path is set (should have a default from PROCESS_MODE_ENV_DEFAULT_CONFIG)
+                if env_config.data_path_to_save_groups is None:
+                    # This check might be redundant if the default is always set, but good practice.
+                    raise ValueError(
+                        "data_path_to_save_groups must be set for process mode"
+                    )
+
                 print(
                     f"Processing {env_config.total_steps} groups of "
                     f"{env_config.group_size} responses and "
                     f"writing to {env_config.data_path_to_save_groups}"
                 )
 
+                # Run the environment's asynchronous process manager function
                 asyncio.run(env.process_manager())
 
-        # Actual implementation would go here
-
         return CliProcessConfig
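Both commands above start from whatever `config_init` returns. A hedged sketch of such a pair follows; the field names are taken from `environments/configs/example.yaml` in this diff, the values are placeholders, and in a real environment this would be a classmethod on the environment class rather than a free function.

```python
# Illustrative (env_config, server_configs) pair of the kind config_init() returns;
# field names follow environments/configs/example.yaml, values are placeholders.
from atroposlib.envs.base import BaseEnvConfig
from atroposlib.envs.server_handling.openai_server import OpenaiConfig


def config_init():
    env_config = BaseEnvConfig(
        group_size=4,
        tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
        use_wandb=True,
        rollout_server_url="http://localhost:8000",
        wandb_name="example_env",
    )
    server_configs = [
        OpenaiConfig(
            model_name="Qwen/Qwen2.5-1.5B-Instruct",
            base_url="http://localhost:9001/v1",
            api_key="x",
        )
    ]
    return env_config, server_configs
```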
diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py
index 4a3562cb..f161869e 100644
--- a/atroposlib/envs/server_handling/openai_server.py
+++ b/atroposlib/envs/server_handling/openai_server.py
@@ -10,8 +10,11 @@ import openai
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.completion import Completion
 from pydantic import BaseModel, Field
+from pydantic_cli import FailedExecutionException
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 
+from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
+
 
 class OpenaiConfig(BaseModel):
     """
@@ -294,3 +297,79 @@ class OpenAIServer:
             self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
             self.eval_attempts_list.append(stat_dict["attempts"])
         return ret_data
+
+
+def resolve_openai_configs(
+    default_server_configs,
+    openai_config_dict,
+    yaml_config,
+    cli_passed_flags,
+    logger,
+):
+    """
+    Helper to resolve the final server_configs, handling single, multiple servers, and overrides.
+    """
+    from atroposlib.envs.server_handling.server_manager import ServerBaseline
+
+    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
+    openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None)
+    openai_cli_config = {
+        k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix)
+    }
+
+    is_multi_server_yaml = (
+        isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2
+    )
+    is_multi_server_default = (
+        (not is_multi_server_yaml)
+        and isinstance(default_server_configs, list)
+        and len(default_server_configs) >= 2
+    )
+
+    if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
+        raise FailedExecutionException(
+            f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported "
+            f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' "
+            "or a default list with length >= 2)."
+        )
+
+    if is_multi_server_yaml:
+        logger.info(
+            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
+        )
+        try:
+            server_configs = [OpenaiConfig(**cfg) for cfg in openai_yaml_config]
+        except Exception as e:
+            raise FailedExecutionException(
+                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
+            ) from e
+    elif isinstance(default_server_configs, ServerBaseline):
+        logger.info("Using ServerBaseline configuration.")
+        server_configs = default_server_configs
+    elif is_multi_server_default:
+        logger.info("Using default multi-server configuration (length >= 2).")
+        server_configs = default_server_configs
+    else:
+        logger.info(
+            "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
+        )
+        try:
+            final_openai_config = OpenaiConfig(**openai_config_dict)
+        except Exception as e:
+            raise FailedExecutionException(
+                f"Error creating final OpenAI configuration from merged settings: {e}\n"
+                f"Merged Dict: {openai_config_dict}"
+            ) from e
+
+        if isinstance(default_server_configs, OpenaiConfig):
+            server_configs = final_openai_config
+        elif isinstance(default_server_configs, list):
+            server_configs = [final_openai_config]
+        else:
+            logger.warning(
+                f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
+                f"Proceeding with single OpenAI server configuration based on merged settings."
+            )
+            server_configs = [final_openai_config]
+
+    return server_configs
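A hedged usage sketch for the new helper follows. The YAML shape mirrors `environments/configs/example.yaml` with a second server added for illustration; with two or more servers in the YAML and no `--openai.*` CLI overrides, the function returns one `OpenaiConfig` per entry. Import paths are the ones shown in this diff, and all values are placeholders.

```python
# Sketch: resolving a multi-server YAML configuration with resolve_openai_configs.
import logging

from atroposlib.envs.constants import OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import (
    OpenaiConfig,
    resolve_openai_configs,
)

yaml_config = {
    OPENAI_NAMESPACE: [
        {"model_name": "Qwen/Qwen2.5-1.5B-Instruct", "base_url": "http://localhost:9001/v1", "api_key": "x"},
        {"model_name": "Qwen/Qwen2.5-1.5B-Instruct", "base_url": "http://localhost:9002/v1", "api_key": "x"},
    ]
}

servers = resolve_openai_configs(
    default_server_configs=[
        OpenaiConfig(model_name="gpt-4.1-nano", base_url="https://api.openai.com/v1", api_key="x")
    ],
    openai_config_dict={},  # unused on the multi-server YAML path
    yaml_config=yaml_config,
    cli_passed_flags={},    # no --openai.* overrides allowed in this case
    logger=logging.getLogger(__name__),
)
print(len(servers))  # 2
```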
diff --git a/atroposlib/utils/cli.py b/atroposlib/utils/cli.py
index a20b6d6d..1ca6082f 100644
--- a/atroposlib/utils/cli.py
+++ b/atroposlib/utils/cli.py
@@ -156,24 +156,41 @@ def get_double_dash_flags() -> Dict[str, Any]:
 
         # Remove '--' prefix
         key_part = arg[2:]
+        key = ""
+        value_str = (
+            None  # Variable to hold the string value before potential conversion
+        )
 
         # Check for '--key=value' format
         if "=" in key_part:
-            key, value = key_part.split("=", 1)
-            if key:  # Ensure key is not empty (e.g. --=value)
-                flags_dict[key] = value
+            key, value_str = key_part.split("=", 1)
+            if not key:  # Ensure key is not empty (e.g. --=value)
+                i += 1
+                continue  # Skip if key is empty
+
+            # Process value: Convert "None" string to None object
+            if value_str == "None":
+                flags_dict[key] = None
+            else:
+                flags_dict[key] = value_str
             i += 1
         # Check if next argument exists and is a value (doesn't start with '-')
         elif i + 1 < len(args) and not args[i + 1].startswith("-"):
             key = key_part
-            value = args[i + 1]
-            flags_dict[key] = value
+            value_str = args[i + 1]
+
+            # Process value: Convert "None" string to None object
+            if value_str == "None":
+                flags_dict[key] = None
+            else:
+                flags_dict[key] = value_str
             # Skip the next argument since we've consumed it as a value
             i += 2
         # Otherwise, treat as a boolean flag
         else:
             key = key_part
-            flags_dict[key] = True
+            if key:  # Ensure key is not empty (e.g. just '--')
+                flags_dict[key] = True
             i += 1
 
     return flags_dict
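The example below illustrates the parsing rules in the hunk above: `--key=value` and `--key value` both yield strings, a bare flag yields `True`, and the literal string "None" is now converted to `None`. It assumes `get_double_dash_flags` reads `sys.argv`, which its zero-argument signature implies but this diff does not show.

```python
# Hedged illustration of get_double_dash_flags after this change (assumes it reads sys.argv).
import sys

from atroposlib.utils.cli import get_double_dash_flags

sys.argv = [
    "gsm8k_server.py",
    "--env.group_size", "8",                           # --key value  -> "8"
    "--openai.model_name=Qwen/Qwen2.5-1.5B-Instruct",  # --key=value  -> str
    "--env.data_path_to_save_groups", "None",          # "None"       -> None
    "--slurm",                                         # bare flag    -> True
]
print(get_double_dash_flags())
# {'env.group_size': '8',
#  'openai.model_name': 'Qwen/Qwen2.5-1.5B-Instruct',
#  'env.data_path_to_save_groups': None,
#  'slurm': True}
```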
diff --git a/environments/configs/example.yaml b/environments/configs/example.yaml
new file mode 100644
index 00000000..680458cf
--- /dev/null
+++ b/environments/configs/example.yaml
@@ -0,0 +1,21 @@
+# Environment configuration
+env:
+  group_size: 4
+  max_batches_offpolicy: 3
+  tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct"
+  use_wandb: true
+  rollout_server_url: "http://localhost:8000"
+  wandb_name: "example_env"
+  ensure_scores_are_not_same: true
+  data_path_to_save_groups: null
+  include_messages: true  # if data_path_to_save_groups is set this will add the messages to the saved .jsonl
+
+# OpenAI server configurations
+openai:
+  - model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+    base_url: "http://localhost:9001/v1"
+    api_key: "x"
+    weight: 1.0
+
+slurm: false
+testing: false
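The top-level keys of this file map onto the namespaces used in `base.py`: the `env` block is merged into the environment config, each entry under `openai` becomes an `OpenaiConfig`, and `slurm`/`testing` feed `ServerManagerConfig`. A small sketch of that mapping follows, assuming the constants resolve to the keys this file uses and that it is run from the repository root.

```python
# Hedged sketch: how the serve command's run() slices example.yaml by namespace.
import yaml

from atroposlib.envs.constants import ENV_NAMESPACE, OPENAI_NAMESPACE

with open("environments/configs/example.yaml") as f:
    cfg = yaml.safe_load(f)

env_section = cfg.get(ENV_NAMESPACE, {})        # {"group_size": 4, ...}
openai_servers = cfg.get(OPENAI_NAMESPACE, [])  # list of per-server dicts
print(env_section["group_size"], openai_servers[0]["base_url"], cfg["slurm"])
# 4 http://localhost:9001/v1 False
```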
diff --git a/pyproject.toml b/pyproject.toml
index 36b6f2bb..ddd8ffbe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "math-verify==0.7.0",
     "jinja2",
     "nltk",
+    "rich",
     "polars",
     "aiofiles",
     "jsonlines",