mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Merge branch 'pipelineRL' into OnPolicyDistillation
This commit is contained in:
commit
33f5696171
23 changed files with 6975 additions and 758 deletions
|
|
@ -4,11 +4,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import wandb
|
||||
|
|
@ -30,6 +31,7 @@ from atroposlib.envs.base import (
|
|||
ScoredDataGroup,
|
||||
ServerBaseline,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
|
||||
prompt_format = (
|
||||
"A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
|
||||
|
|
@ -125,7 +127,7 @@ class MathEnv(BaseEnv):
|
|||
def __init__(
|
||||
self,
|
||||
config: RSConfig,
|
||||
server_configs: ServerBaseline,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -152,26 +154,41 @@ class MathEnv(BaseEnv):
|
|||
print("=" * 60)
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
|
||||
def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
|
||||
# Allow configuration via environment variables for running multiple instances
|
||||
model_name = os.environ.get("MATH_ENV_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
|
||||
rollout_url = os.environ.get("MATH_ENV_ROLLOUT_URL", "http://localhost:8000")
|
||||
vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1")
|
||||
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
|
||||
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "32000"))
|
||||
worker_timeout = float(os.environ.get("MATH_ENV_WORKER_TIMEOUT", "1500"))
|
||||
|
||||
env_config = RSConfig(
|
||||
tokenizer_name="Qwen/Qwen2.5-7B",
|
||||
group_size=16,
|
||||
tokenizer_name=model_name,
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1024,
|
||||
steps_per_eval=25,
|
||||
max_token_length=31000, # 22000 // (2 ** i),
|
||||
wandb_name="math",
|
||||
rollout_server_url=rollout_url,
|
||||
total_steps=120,
|
||||
batch_size=64,
|
||||
steps_per_eval=20,
|
||||
max_token_length=max_token_length,
|
||||
start_tok_length=max_token_length,
|
||||
wandb_name=wandb_name,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
max_num_workers_per_node=24,
|
||||
worker_timeout=worker_timeout,
|
||||
)
|
||||
server_configs = ServerBaseline(
|
||||
model_name="Qwen/Qwen2.5-7B",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
server_type="vllm",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model_name,
|
||||
base_url=vllm_url,
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
server_type="vllm",
|
||||
weight=1.0,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
|
@ -352,7 +369,7 @@ class MathEnv(BaseEnv):
|
|||
completion = await managed.completion(
|
||||
prompt=question,
|
||||
n=1,
|
||||
max_tokens=32765,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
stop=stop_list,
|
||||
|
|
@ -376,6 +393,10 @@ class MathEnv(BaseEnv):
|
|||
async def evaluate(self, *args, **kwargs):
|
||||
if not self.config.run_evaluation:
|
||||
return
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
eval_tasks = []
|
||||
for item in self.test:
|
||||
eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2]))
|
||||
|
|
@ -385,17 +406,53 @@ class MathEnv(BaseEnv):
|
|||
if subset not in task_lists:
|
||||
task_lists[subset] = list()
|
||||
task_lists[subset].append(score)
|
||||
# Now get the average
|
||||
|
||||
# Build metrics dictionary for saving
|
||||
metrics = {}
|
||||
|
||||
# Now get the average per subset
|
||||
for subset, scores in task_lists.items():
|
||||
self.eval_metrics.append(
|
||||
(f"eval/{subset}_percent_correct", sum(scores) / len(scores))
|
||||
)
|
||||
accuracy = sum(scores) / len(scores)
|
||||
metrics[f"{subset}_accuracy"] = accuracy
|
||||
metrics[f"{subset}_total"] = len(scores)
|
||||
metrics[f"{subset}_correct"] = sum(scores)
|
||||
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
|
||||
|
||||
# overall score
|
||||
scores = []
|
||||
all_scores = []
|
||||
for subset, score in task_lists.items():
|
||||
scores.extend(score)
|
||||
self.eval_metrics.append(
|
||||
("eval/overall_percent_correct", sum(scores) / len(scores))
|
||||
all_scores.extend(score)
|
||||
overall_accuracy = sum(all_scores) / len(all_scores)
|
||||
metrics["overall_accuracy"] = overall_accuracy
|
||||
metrics["overall_total"] = len(all_scores)
|
||||
metrics["overall_correct"] = sum(all_scores)
|
||||
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Print results to console
|
||||
print("\n" + "=" * 60)
|
||||
print("Math Zero Evaluation Results")
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
|
||||
)
|
||||
print("\nPer-subset breakdown:")
|
||||
for subset, scores in sorted(task_lists.items()):
|
||||
acc = sum(scores) / len(scores)
|
||||
print(f" {subset}: {acc:.2%} ({sum(scores)}/{len(scores)})")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Save results to disk
|
||||
await self.evaluate_log(
|
||||
metrics=metrics,
|
||||
task_name="math_zero",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
async def collect_trajectories(self, item) -> Tuple[List, List]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue