mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
math zero work arounds
This commit is contained in:
parent
a9ebdc50b8
commit
d07ab3e3ce
3 changed files with 31 additions and 32 deletions
|
|
@ -4,10 +4,11 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
|
|
@ -24,6 +25,7 @@ from atroposlib.envs.base import (
|
|||
ScoredDataGroup,
|
||||
ServerBaseline,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
|
||||
prompt_format = (
|
||||
"A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
|
||||
|
|
@ -119,7 +121,7 @@ class MathEnv(BaseEnv):
|
|||
def __init__(
|
||||
self,
|
||||
config: RSConfig,
|
||||
server_configs: ServerBaseline,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -137,26 +139,39 @@ class MathEnv(BaseEnv):
|
|||
self.iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
|
||||
def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
|
||||
# Allow configuration via environment variables for running multiple instances
|
||||
model_name = os.environ.get("MATH_ENV_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
|
||||
rollout_url = os.environ.get("MATH_ENV_ROLLOUT_URL", "http://localhost:8000")
|
||||
vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1")
|
||||
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
|
||||
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "8192"))
|
||||
|
||||
env_config = RSConfig(
|
||||
tokenizer_name="Qwen/Qwen2.5-7B",
|
||||
group_size=16,
|
||||
tokenizer_name=model_name,
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1024,
|
||||
steps_per_eval=25,
|
||||
max_token_length=31000, # 22000 // (2 ** i),
|
||||
wandb_name="math",
|
||||
rollout_server_url=rollout_url,
|
||||
total_steps=120,
|
||||
batch_size=64,
|
||||
steps_per_eval=20,
|
||||
max_token_length=max_token_length,
|
||||
start_tok_length=max_token_length,
|
||||
wandb_name=wandb_name,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
max_num_workers_per_node=24,
|
||||
)
|
||||
server_configs = ServerBaseline(
|
||||
model_name="Qwen/Qwen2.5-7B",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
server_type="vllm",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model_name,
|
||||
base_url=vllm_url,
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
server_type="vllm",
|
||||
weight=1.0,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,3 @@ env:
|
|||
wandb_name: "math-zero-lora-env"
|
||||
eval_limit_ratio: 0.1
|
||||
max_num_workers_per_node: 24
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:9002/v1"
|
||||
model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
server_type: "vllm"
|
||||
api_key: "x"
|
||||
num_requests_for_eval: 256
|
||||
weight: 1.0
|
||||
|
|
|
|||
|
|
@ -11,11 +11,3 @@ env:
|
|||
wandb_name: "math-zero-shared-env"
|
||||
eval_limit_ratio: 0.1
|
||||
max_num_workers_per_node: 24
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:9001/v1"
|
||||
model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
server_type: "vllm"
|
||||
api_key: "x"
|
||||
num_requests_for_eval: 256
|
||||
weight: 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue