diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 1432ab4d..0d6c140c 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -18,6 +18,7 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, @@ -119,7 +120,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: ServerBaseline, + server_configs: APIServerConfig | ServerBaseline, slurm=True, testing=False, ): @@ -137,7 +138,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: + def config_init(cls) -> Tuple[RSConfig, APIServerConfig]: env_config = RSConfig( tokenizer_name="Qwen/Qwen2.5-7B", group_size=16, @@ -152,10 +153,11 @@ class MathEnv(BaseEnv): eval_limit_ratio=0.1, max_num_workers_per_node=24, ) - server_configs = ServerBaseline( + server_configs = APIServerConfig( model_name="Qwen/Qwen2.5-7B", num_requests_for_eval=256, # since evaling only on one... server_type="vllm", + base_url="", # Override via CLI: --openai.base_url ) return env_config, server_configs