mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
switch to evalbase for verifiers_eval.py
This commit is contained in:
parent
57fa229846
commit
c56af35eaa
2 changed files with 216 additions and 211 deletions
|
|
@ -14,32 +14,25 @@ To install a Verifiers/Prime environment:
|
|||
Docs: https://docs.primeintellect.ai/tutorials-environments/install
|
||||
|
||||
Usage:
|
||||
python verifiers_eval.py evaluate \
|
||||
--env.vf_env_name primeintellect/gsm8k \
|
||||
--openai.model_name gpt-4.1-nano \
|
||||
--openai.api_key $OPENAI_API_KEY
|
||||
python verifiers_eval.py \
|
||||
--server-url http://localhost:8000/v1 \
|
||||
--model-name Qwen/Qwen2.5-7B-Instruct \
|
||||
--vf-env-name primeintellect/gsm8k \
|
||||
--max-eval-items 100
|
||||
|
||||
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.):
|
||||
python verifiers_eval.py evaluate \
|
||||
--env.vf_env_name primeintellect/gsm8k \
|
||||
--openai.model_name Qwen/Qwen2.5-7B-Instruct \
|
||||
--openai.base_url http://localhost:8000/v1
|
||||
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import verifiers as vf
|
||||
import wandb
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
)
|
||||
from atroposlib.envs.eval import EvalBase, evaluate_log
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
|
||||
# Patch math_verify timeout to work in async context
|
||||
|
|
@ -70,149 +63,113 @@ except ImportError:
|
|||
pass # math_verify not installed
|
||||
|
||||
|
||||
class VerifiersEvaluationConfig(BaseEnvConfig):
|
||||
"""Configuration for Verifiers evaluation environment."""
|
||||
|
||||
vf_env_name: str = Field(
|
||||
default="",
|
||||
description="Verifiers environment name (e.g., primeintellect/gsm8k)",
|
||||
)
|
||||
env_args: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional arguments for verifiers environment",
|
||||
)
|
||||
temperature: float = Field(
|
||||
default=0.0,
|
||||
description="Temperature for generation (0.0 for deterministic)",
|
||||
)
|
||||
max_eval_items: int = Field(
|
||||
default=-1,
|
||||
description="Maximum number of items to evaluate (-1 for all)",
|
||||
)
|
||||
max_concurrent: int = Field(
|
||||
default=64,
|
||||
description="Maximum concurrent requests to the model",
|
||||
)
|
||||
|
||||
# Override BaseEnvConfig defaults for evaluation
|
||||
group_size: int = 1
|
||||
max_num_workers: int = 1024
|
||||
max_eval_workers: int = 256
|
||||
max_num_workers_per_node: int = 128
|
||||
use_wandb: bool = True
|
||||
rollout_server_url: str = "http://localhost:8000"
|
||||
total_steps: int = 1
|
||||
steps_per_eval: int = 1
|
||||
wandb_name: str = "verifiers_eval"
|
||||
|
||||
|
||||
class VerifiersEvaluationEnv(BaseEnv):
|
||||
class VerifiersEval(EvalBase):
|
||||
"""
|
||||
Verifiers Evaluation Environment.
|
||||
Verifiers Evaluation using EvalBase pattern.
|
||||
|
||||
Evaluates models using Prime Intellect's Verifiers library.
|
||||
Uses verifiers' native rollout and scoring machinery.
|
||||
Uses verifiers' native batch evaluation for efficiency,
|
||||
with EvalBase's standardized logging via evaluate_log().
|
||||
|
||||
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
|
||||
"""
|
||||
|
||||
name = "verifiers_evaluation"
|
||||
env_config_cls = VerifiersEvaluationConfig # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VerifiersEvaluationConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
vf_env_name: str = "primeintellect/gsm8k",
|
||||
env_args: dict = None,
|
||||
temperature: float = 0.0,
|
||||
max_tokens: int = 2048,
|
||||
max_eval_items: int = -1,
|
||||
max_concurrent: int = 64,
|
||||
eval_dir: str = None,
|
||||
verbose: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: VerifiersEvaluationConfig = config
|
||||
self.vf_env_name = vf_env_name
|
||||
self.env_args = env_args or {}
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.max_eval_items = max_eval_items
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
|
||||
|
||||
# Get reward function names for metrics reporting
|
||||
# Load verifiers environment
|
||||
self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
|
||||
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for evaluation."""
|
||||
env_config = VerifiersEvaluationConfig(
|
||||
vf_env_name="primeintellect/gsm8k",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
"""Create AsyncOpenAI client from first server config."""
|
||||
server = self.server.servers[0]
|
||||
config = server.config
|
||||
return AsyncOpenAI(
|
||||
api_key=config.api_key or "x",
|
||||
base_url=config.base_url,
|
||||
timeout=config.timeout,
|
||||
# Initialize EvalBase (calls setup_data)
|
||||
super().__init__(
|
||||
eval_dir=eval_dir,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_model_name(self) -> str:
|
||||
"""Get model name from first server config."""
|
||||
return self.server.servers[0].config.model_name
|
||||
def get_generation_params(self):
|
||||
"""Generation params for logging."""
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"n": 1,
|
||||
}
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Initialize the environment."""
|
||||
num_eval = len(self.vf_env.get_eval_dataset())
|
||||
if self.config.max_eval_items > 0:
|
||||
num_eval = min(num_eval, self.config.max_eval_items)
|
||||
def setup_data(self) -> list:
|
||||
"""Return evaluation dataset from verifiers environment."""
|
||||
dataset = self.vf_env.get_eval_dataset()
|
||||
if self.max_eval_items > 0:
|
||||
n = min(len(dataset), self.max_eval_items)
|
||||
dataset = dataset.select(range(n))
|
||||
return dataset.to_list()
|
||||
|
||||
print("\nVerifiers Evaluation Setup:")
|
||||
print(f" Environment: {self.config.vf_env_name}")
|
||||
print(f" Reward functions: {self.reward_func_names}")
|
||||
print(f" Evaluation items: {num_eval}")
|
||||
print(f" Max concurrent: {self.config.max_concurrent}")
|
||||
|
||||
async def evaluate(self) -> Dict:
|
||||
"""Run evaluation using verifiers' native machinery."""
|
||||
num_examples = (
|
||||
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
|
||||
async def run_item(
|
||||
self, server: ServerManager, data_item: dict # noqa: ARG002
|
||||
) -> Tuple[dict, list]:
|
||||
"""Not used - we override __call__ for batch evaluation."""
|
||||
raise NotImplementedError(
|
||||
"VerifiersEval uses batch evaluation via __call__, not per-item run_item"
|
||||
)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Model: {self._get_model_name()}")
|
||||
print(f" Temperature: {self.config.temperature}")
|
||||
print(f" Max concurrent: {self.config.max_concurrent}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
async def __call__(self, server_manager: ServerManager):
|
||||
"""Run evaluation using verifiers' native batch machinery."""
|
||||
start_time = time.time()
|
||||
|
||||
# Create OpenAI client from atropos server config
|
||||
client = self._get_openai_client()
|
||||
model = self._get_model_name()
|
||||
# Create OpenAI client from server config
|
||||
server = server_manager.servers[0]
|
||||
client = AsyncOpenAI(
|
||||
api_key=server.config.api_key or "x",
|
||||
base_url=server.config.base_url,
|
||||
timeout=getattr(server.config, "timeout", 600),
|
||||
)
|
||||
model = server.config.model_name
|
||||
|
||||
# Let verifiers handle everything: rollouts + scoring
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Verifiers Evaluation: {self.vf_env_name}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Model: {model}")
|
||||
print(f" Items: {len(self.data)}")
|
||||
print(f" Reward functions: {self.reward_func_names}")
|
||||
print(f" Temperature: {self.temperature}")
|
||||
print(f" Max concurrent: {self.max_concurrent}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
|
||||
|
||||
# Use verifiers' batch evaluation
|
||||
results = await self.vf_env.evaluate(
|
||||
client=client,
|
||||
model=model,
|
||||
sampling_args={
|
||||
"temperature": self.config.temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
},
|
||||
num_examples=num_examples,
|
||||
max_concurrent=self.config.max_concurrent,
|
||||
max_concurrent=self.max_concurrent,
|
||||
save_results=False,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Extract metrics from verifiers output
|
||||
# Extract from verifiers output
|
||||
rewards = results["reward"]
|
||||
per_func_metrics = results["metrics"] # dict of func_name -> list[float]
|
||||
per_func_metrics = results["metrics"]
|
||||
prompts = results["prompt"]
|
||||
completions = results["completion"]
|
||||
answers = results["answer"]
|
||||
|
|
@ -232,14 +189,11 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
}
|
||||
|
||||
metrics = {
|
||||
"avg_score": avg_score,
|
||||
"accuracy": accuracy,
|
||||
"total_evaluated": total,
|
||||
"total_correct": correct,
|
||||
"reward_breakdown": reward_breakdown,
|
||||
"avg_score": avg_score,
|
||||
}
|
||||
|
||||
# Print results
|
||||
# Print results summary
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Verifiers Evaluation Results")
|
||||
print(f"{'=' * 60}")
|
||||
|
|
@ -253,14 +207,13 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
)
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Log to evaluate_log (atropos's logging system)
|
||||
# Build samples for logging
|
||||
system_prompt = self.vf_env.system_prompt or ""
|
||||
samples = []
|
||||
for i in range(min(total, 100)): # Limit samples for logging
|
||||
prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
|
||||
completion_msgs = completions[i] if completions[i] else []
|
||||
|
||||
# Build full message list
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
|
@ -280,59 +233,114 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics={"accuracy": accuracy, "avg_score": avg_score},
|
||||
samples=samples,
|
||||
# Use EvalBase's evaluate_log
|
||||
task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
|
||||
evaluate_log(
|
||||
metrics=metrics,
|
||||
eval_dir=getattr(self, "eval_dir", None),
|
||||
task_name=task_name,
|
||||
model_name=model,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
},
|
||||
generation_parameters=self.get_generation_params(),
|
||||
samples=samples,
|
||||
verbose=getattr(self, "verbose", False),
|
||||
)
|
||||
|
||||
# Log to wandb
|
||||
await self.wandb_log(metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log metrics to Weights & Biases."""
|
||||
if not self.config.use_wandb or wandb_metrics is None:
|
||||
return
|
||||
|
||||
# Lazy init if wandb not already initialized
|
||||
if wandb.run is None:
|
||||
wandb.init(
|
||||
project="atropos-environments",
|
||||
name=self.config.wandb_name,
|
||||
config=self.config.model_dump(),
|
||||
)
|
||||
async def main():
|
||||
"""CLI entry point for verifiers evaluation."""
|
||||
import os
|
||||
|
||||
log_dict = {
|
||||
"verifiers/accuracy": wandb_metrics.get("accuracy", 0),
|
||||
"verifiers/avg_score": wandb_metrics.get("avg_score", 0),
|
||||
"verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0),
|
||||
"verifiers/total_correct": wandb_metrics.get("total_correct", 0),
|
||||
}
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
|
||||
# Add per-reward function metrics
|
||||
reward_breakdown = wandb_metrics.get("reward_breakdown", {})
|
||||
for func_name, data in reward_breakdown.items():
|
||||
log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0)
|
||||
log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0)
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate models using Verifiers environments"
|
||||
)
|
||||
# Server args (same as eval_runner)
|
||||
parser.add_argument(
|
||||
"--server-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="URL of the inference server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name to evaluate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=os.getenv("OPENAI_API_KEY", "x"),
|
||||
help="API key (defaults to OPENAI_API_KEY env var)",
|
||||
)
|
||||
# Verifiers-specific args
|
||||
parser.add_argument(
|
||||
"--vf-env-name",
|
||||
type=str,
|
||||
default="primeintellect/gsm8k",
|
||||
help="Verifiers environment name (e.g., primeintellect/gsm8k)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-eval-items",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum items to evaluate (-1 for all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Generation temperature",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum tokens per completion",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrent",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Maximum concurrent requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save evaluation results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
wandb.log(log_dict)
|
||||
# Create server manager
|
||||
server_manager = ServerManager(
|
||||
configs=[
|
||||
APIServerConfig(
|
||||
api_key=args.api_key,
|
||||
base_url=args.server_url,
|
||||
model_name=args.model_name,
|
||||
health_check=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Required abstract method implementations (stubs for evaluation-only mode)
|
||||
async def get_next_item(self) -> Optional[Dict]:
|
||||
"""Not used in evaluation mode."""
|
||||
return None
|
||||
|
||||
async def collect_trajectories(self, item) -> Tuple[List, List]: # noqa: ARG002
|
||||
"""Not used in evaluation mode."""
|
||||
return [], []
|
||||
# Create and run evaluation
|
||||
eval_instance = VerifiersEval(
|
||||
vf_env_name=args.vf_env_name,
|
||||
max_eval_items=args.max_eval_items,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
max_concurrent=args.max_concurrent,
|
||||
eval_dir=args.eval_dir,
|
||||
verbose=True,
|
||||
)
|
||||
return await eval_instance(server_manager)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
VerifiersEvaluationEnv.cli()
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue