diff --git a/environments/eval_environments/gsm8k_eval.py b/environments/eval_environments/gsm8k_eval.py
index 9cacb5b4..06f4c28e 100644
--- a/environments/eval_environments/gsm8k_eval.py
+++ b/environments/eval_environments/gsm8k_eval.py
@@ -20,7 +20,7 @@ Supports thinking mode with tags for extended reasoning.
 import asyncio
 import random
 from concurrent.futures import ProcessPoolExecutor
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import wandb
 from datasets import load_dataset
@@ -147,6 +147,26 @@ class GSM8KEvalEnv(BaseEnv):
     def config_cls(cls) -> type:
         return GSM8KEvalConfig
 
+    @classmethod
+    def config_init(cls) -> Tuple[GSM8KEvalConfig, List[APIServerConfig]]:
+        """Initialize default configuration for the environment."""
+        env_config = GSM8KEvalConfig(
+            tokenizer_name="Qwen/Qwen2.5-3B-Instruct",
+            group_size=1,
+            use_wandb=False,
+            max_num_workers_per_node=128,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1,
+            wandb_name="gsm8k_eval",
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-3B-Instruct",
+                base_url="http://localhost:9001/v1",
+            )
+        ]
+        return env_config, server_configs
+
     async def setup(self) -> None:
         """Initialize the environment and load the dataset."""
         # Initialize math executor