Merge branch 'pipelineRL' into OnPolicyDistillation

Jai Suphavadeeprasit 2026-02-19 16:39:21 -05:00
commit 33f5696171
23 changed files with 6975 additions and 758 deletions


@@ -4,11 +4,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
"""
import asyncio
import os
import random
import re
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import aiohttp
import wandb
@@ -30,6 +31,7 @@ from atroposlib.envs.base import (
ScoredDataGroup,
ServerBaseline,
)
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
prompt_format = (
"A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
@@ -125,7 +127,7 @@ class MathEnv(BaseEnv):
def __init__(
self,
config: RSConfig,
server_configs: ServerBaseline,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm=True,
testing=False,
):
@@ -152,26 +154,41 @@ class MathEnv(BaseEnv):
print("=" * 60)
@classmethod
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
# Allow configuration via environment variables for running multiple instances
model_name = os.environ.get("MATH_ENV_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
rollout_url = os.environ.get("MATH_ENV_ROLLOUT_URL", "http://localhost:8000")
vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1")
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "32000"))
worker_timeout = float(os.environ.get("MATH_ENV_WORKER_TIMEOUT", "1500"))
env_config = RSConfig(
tokenizer_name="Qwen/Qwen2.5-7B",
group_size=16,
tokenizer_name=model_name,
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
steps_per_eval=25,
max_token_length=31000, # 22000 // (2 ** i),
wandb_name="math",
rollout_server_url=rollout_url,
total_steps=120,
batch_size=64,
steps_per_eval=20,
max_token_length=max_token_length,
start_tok_length=max_token_length,
wandb_name=wandb_name,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
max_num_workers_per_node=24,
worker_timeout=worker_timeout,
)
server_configs = ServerBaseline(
model_name="Qwen/Qwen2.5-7B",
num_requests_for_eval=256, # since evaling only on one...
server_type="vllm",
)
server_configs = [
APIServerConfig(
model_name=model_name,
base_url=vllm_url,
api_key="x",
num_requests_for_eval=256,
server_type="vllm",
weight=1.0,
)
]
return env_config, server_configs
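
config_init() now reads its model name, server URLs, and token/timeout limits from MATH_ENV_* environment variables, so several copies of the environment can run side by side against different servers. A minimal sketch of how a second instance might be launched (the ports and run name below are hypothetical, not taken from this commit):

# Hypothetical launcher for a second math-zero instance; it only sets the
# MATH_ENV_* variables that config_init() above reads via os.environ.get().
import os

os.environ["MATH_ENV_MODEL"] = "Qwen/Qwen3-4B-Instruct-2507"
os.environ["MATH_ENV_VLLM_URL"] = "http://localhost:9002/v1"  # assumed second vLLM server
os.environ["MATH_ENV_ROLLOUT_URL"] = "http://localhost:8001"  # assumed second rollout server
os.environ["MATH_ENV_WANDB_NAME"] = "math-zero-env-2"
os.environ["MATH_ENV_MAX_TOKENS"] = "16000"

env_config, server_configs = MathEnv.config_init()
print(server_configs[0].base_url)  # -> http://localhost:9002/v1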
@@ -352,7 +369,7 @@ class MathEnv(BaseEnv):
completion = await managed.completion(
prompt=question,
n=1,
max_tokens=32765,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
stop=stop_list,
@@ -376,6 +393,10 @@ class MathEnv(BaseEnv):
async def evaluate(self, *args, **kwargs):
if not self.config.run_evaluation:
return
import time
start_time = time.time()
eval_tasks = []
for item in self.test:
eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2]))
@@ -385,17 +406,53 @@ class MathEnv(BaseEnv):
if subset not in task_lists:
task_lists[subset] = list()
task_lists[subset].append(score)
# Now get the average
# Build metrics dictionary for saving
metrics = {}
# Now get the average per subset
for subset, scores in task_lists.items():
self.eval_metrics.append(
(f"eval/{subset}_percent_correct", sum(scores) / len(scores))
)
accuracy = sum(scores) / len(scores)
metrics[f"{subset}_accuracy"] = accuracy
metrics[f"{subset}_total"] = len(scores)
metrics[f"{subset}_correct"] = sum(scores)
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
# overall score
scores = []
all_scores = []
for subset, score in task_lists.items():
scores.extend(score)
self.eval_metrics.append(
("eval/overall_percent_correct", sum(scores) / len(scores))
all_scores.extend(score)
overall_accuracy = sum(all_scores) / len(all_scores)
metrics["overall_accuracy"] = overall_accuracy
metrics["overall_total"] = len(all_scores)
metrics["overall_correct"] = sum(all_scores)
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
end_time = time.time()
# Print results to console
print("\n" + "=" * 60)
print("Math Zero Evaluation Results")
print("=" * 60)
print(
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
)
print("\nPer-subset breakdown:")
for subset, scores in sorted(task_lists.items()):
acc = sum(scores) / len(scores)
print(f" {subset}: {acc:.2%} ({sum(scores)}/{len(scores)})")
print("=" * 60 + "\n")
# Save results to disk
await self.evaluate_log(
metrics=metrics,
task_name="math_zero",
start_time=start_time,
end_time=end_time,
generation_parameters={
"max_tokens": self.config.max_token_length,
"temperature": 0.0,
},
)
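
The rewritten evaluate() aggregates the per-item 0/1 scores into a metrics dictionary (per-subset accuracy and counts plus an overall score), prints a console summary, and persists everything through evaluate_log(). A toy sketch of just the aggregation step, assuming task_lists maps a subset name to its list of 0/1 scores (the subset names and scores below are illustrative only, not results from this commit):

# Illustrative only: reproduces the aggregation in the hunk above on fake scores.
task_lists = {"aime24": [1, 0, 1, 1], "amc23": [1, 1, 0]}

metrics = {}
for subset, scores in task_lists.items():
    metrics[f"{subset}_accuracy"] = sum(scores) / len(scores)
    metrics[f"{subset}_total"] = len(scores)
    metrics[f"{subset}_correct"] = sum(scores)

all_scores = [s for scores in task_lists.values() for s in scores]
metrics["overall_accuracy"] = sum(all_scores) / len(all_scores)
# metrics is then passed to self.evaluate_log(...) together with the timing
# and generation parameters, as shown in the diff above.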
async def collect_trajectories(self, item) -> Tuple[List, List]: