diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 81569f0d..75ed08e2 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1654,20 +1654,29 @@ class BaseEnv(ABC): cli_passed_flags, openai_full_prefix ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) - if isinstance(default_server_configs, ServerBaseline) and ( + + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided + # This allows any environment to use --openai.* CLI args without modifying config_init + # Use a new variable to avoid UnboundLocalError from closure scoping + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( oai_cli_passed_args or yaml_oai_config ): - raise ValueError( - "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501 + # Convert ServerBaseline to APIServerConfig, preserving common fields + baseline_dict = effective_server_configs.model_dump() + effective_server_configs = APIServerConfig(**baseline_dict) + logger.info( + "Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides" ) + if ( - isinstance(default_server_configs, list) - and len(default_server_configs) == 1 + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 ): # can't use the same var name because it shadows the class variable and we get an error - default_openai_config_ = default_server_configs[0] + default_openai_config_ = effective_server_configs[0] else: - default_openai_config_ = default_server_configs + default_openai_config_ = effective_server_configs if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: yaml_oai_config = yaml_oai_config[0] if isinstance(default_openai_config_, APIServerConfig) and isinstance( @@ -1717,7 +1726,7 @@ class BaseEnv(ABC): # Determine the final server_configs, handling single, multiple servers, and overrides. openai_configs = resolve_openai_configs( - default_server_configs=default_server_configs, + default_server_configs=effective_server_configs, openai_config_dict=openai_config_dict, yaml_config=yaml_config, cli_passed_flags=cli_passed_flags, diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 0d6c140c..1432ab4d 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -18,7 +18,6 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( - APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, @@ -120,7 +119,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: APIServerConfig | ServerBaseline, + server_configs: ServerBaseline, slurm=True, testing=False, ): @@ -138,7 +137,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(cls) -> Tuple[RSConfig, APIServerConfig]: + def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: env_config = RSConfig( tokenizer_name="Qwen/Qwen2.5-7B", group_size=16, @@ -153,11 +152,10 @@ class MathEnv(BaseEnv): eval_limit_ratio=0.1, max_num_workers_per_node=24, ) - server_configs = APIServerConfig( + server_configs = ServerBaseline( model_name="Qwen/Qwen2.5-7B", num_requests_for_eval=256, # since evaling only on one... server_type="vllm", - base_url="", # Override via CLI: --openai.base_url ) return env_config, server_configs