diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py
index 38acf49d..49e5db7e 100644
--- a/environments/math_server_zero.py
+++ b/environments/math_server_zero.py
@@ -9,6 +9,7 @@ import re
 from concurrent.futures import ProcessPoolExecutor
 from typing import Dict, List, Optional, Tuple
 
+import wandb
 from datasets import load_dataset
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
@@ -16,13 +17,12 @@ from math_verify.errors import TimeoutException
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
 
-import wandb
 from atroposlib.envs.base import (
     BaseEnv,
     BaseEnvConfig,
     EvalHandlingEnum,
-    OpenaiConfig,
     ScoredDataGroup,
+    ServerBaseline,
 )
 
 prompt_format = (
@@ -115,7 +115,7 @@ class MathEnv(BaseEnv):
     def __init__(
         self,
         config: RSConfig,
-        server_configs: List[OpenaiConfig],
+        server_configs: ServerBaseline,
         slurm=True,
         testing=False,
     ):
@@ -133,7 +133,7 @@ class MathEnv(BaseEnv):
         self.iter = 0
 
     @classmethod
-    def config_init(cls) -> Tuple[RSConfig, List[OpenaiConfig]]:
+    def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
         env_config = RSConfig(
             tokenizer_name="Qwen/Qwen2.5-7B",
             group_size=8,
@@ -147,14 +147,10 @@ class MathEnv(BaseEnv):
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
         )
-        server_configs = [
-            OpenaiConfig(
-                model_name="default",
-                base_url="http://localhost:9004/v1",
-                api_key="x",
-                num_requests_for_eval=256,  # since evaling only on one...
-            ),
-        ]
+        server_configs = ServerBaseline(
+            model_name="default",
+            num_requests_for_eval=256,  # since evaling only on one...
+        )
 
         return env_config, server_configs
 
@@ -222,8 +218,8 @@ class MathEnv(BaseEnv):
             )
         )
         for name, t_dataset in zip(
-            ["amc23", "minerva", "olympiad"],
-            [amc_test_data, minerva_test_data, olympiad_test_data],
+            ["amc23", "minerva"],
+            [amc_test_data, minerva_test_data],
         ):
             for item in t_dataset:
                 self.test.append(
@@ -235,6 +231,17 @@ class MathEnv(BaseEnv):
                         name,
                     )
                 )
+        for name, t_dataset in zip(["olympiad"], [olympiad_test_data]):
+            for item in t_dataset:
+                self.test.append(
+                    (
+                        prompt_format.format(
+                            prompt=problem_format.format(problem=item["question"])
+                        ),
+                        item["final_answer"][0],
+                        name,
+                    )
+                )
         return
 
     async def rollout_and_score_eval(self, question, answer, subset):