math zero workarounds

This commit is contained in:
Jai Suphavadeeprasit 2026-02-04 18:01:59 -05:00
parent a9ebdc50b8
commit d07ab3e3ce
3 changed files with 31 additions and 32 deletions

View file

@ -4,10 +4,11 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
"""
import asyncio
import os
import random
import re
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
@ -24,6 +25,7 @@ from atroposlib.envs.base import (
ScoredDataGroup,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
prompt_format = (
"A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
@ -119,7 +121,7 @@ class MathEnv(BaseEnv):
def __init__(
self,
config: RSConfig,
server_configs: ServerBaseline,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm=True,
testing=False,
):
@ -137,26 +139,39 @@ class MathEnv(BaseEnv):
self.iter = 0
@classmethod
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
# Allow configuration via environment variables for running multiple instances
model_name = os.environ.get("MATH_ENV_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
rollout_url = os.environ.get("MATH_ENV_ROLLOUT_URL", "http://localhost:8000")
vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1")
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "8192"))
env_config = RSConfig(
tokenizer_name="Qwen/Qwen2.5-7B",
group_size=16,
tokenizer_name=model_name,
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
steps_per_eval=25,
max_token_length=31000, # 22000 // (2 ** i),
wandb_name="math",
rollout_server_url=rollout_url,
total_steps=120,
batch_size=64,
steps_per_eval=20,
max_token_length=max_token_length,
start_tok_length=max_token_length,
wandb_name=wandb_name,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
max_num_workers_per_node=24,
)
server_configs = ServerBaseline(
model_name="Qwen/Qwen2.5-7B",
num_requests_for_eval=256, # since evaling only on one...
server_type="vllm",
)
server_configs = [
APIServerConfig(
model_name=model_name,
base_url=vllm_url,
api_key="x",
num_requests_for_eval=256,
server_type="vllm",
weight=1.0,
)
]
return env_config, server_configs

View file

@ -11,11 +11,3 @@ env:
wandb_name: "math-zero-lora-env"
eval_limit_ratio: 0.1
max_num_workers_per_node: 24
openai:
base_url: "http://localhost:9002/v1"
model_name: "Qwen/Qwen3-4B-Instruct-2507"
server_type: "vllm"
api_key: "x"
num_requests_for_eval: 256
weight: 1.0

View file

@ -11,11 +11,3 @@ env:
wandb_name: "math-zero-shared-env"
eval_limit_ratio: 0.1
max_num_workers_per_node: 24
openai:
base_url: "http://localhost:9001/v1"
model_name: "Qwen/Qwen3-4B-Instruct-2507"
server_type: "vllm"
api_key: "x"
num_requests_for_eval: 256
weight: 1.0