switch to evalbase for verifiers_eval.py

balyan.sid@gmail.com 2026-01-15 11:34:40 +05:30
parent 57fa229846
commit c56af35eaa
2 changed files with 216 additions and 211 deletions


@@ -14,32 +14,25 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install
Usage:
python verifiers_eval.py evaluate \
--env.vf_env_name primeintellect/gsm8k \
--openai.model_name gpt-4.1-nano \
--openai.api_key $OPENAI_API_KEY
python verifiers_eval.py \
--server-url http://localhost:8000/v1 \
--model-name Qwen/Qwen2.5-7B-Instruct \
--vf-env-name primeintellect/gsm8k \
--max-eval-items 100
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.):
python verifiers_eval.py evaluate \
--env.vf_env_name primeintellect/gsm8k \
--openai.model_name Qwen/Qwen2.5-7B-Instruct \
--openai.base_url http://localhost:8000/v1
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
import os
import argparse
import asyncio
import time
from typing import Any, Dict, List, Optional, Tuple
from typing import Tuple
import verifiers as vf
import wandb
from openai import AsyncOpenAI
from pydantic import Field
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
)
from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.server_manager import ServerManager
# Patch math_verify timeout to work in async context
@@ -70,149 +63,113 @@ except ImportError:
pass # math_verify not installed
class VerifiersEvaluationConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
vf_env_name: str = Field(
default="",
description="Verifiers environment name (e.g., primeintellect/gsm8k)",
)
env_args: Dict[str, Any] = Field(
default_factory=dict,
description="Additional arguments for verifiers environment",
)
temperature: float = Field(
default=0.0,
description="Temperature for generation (0.0 for deterministic)",
)
max_eval_items: int = Field(
default=-1,
description="Maximum number of items to evaluate (-1 for all)",
)
max_concurrent: int = Field(
default=64,
description="Maximum concurrent requests to the model",
)
# Override BaseEnvConfig defaults for evaluation
group_size: int = 1
max_num_workers: int = 1024
max_eval_workers: int = 256
max_num_workers_per_node: int = 128
use_wandb: bool = True
rollout_server_url: str = "http://localhost:8000"
total_steps: int = 1
steps_per_eval: int = 1
wandb_name: str = "verifiers_eval"
class VerifiersEvaluationEnv(BaseEnv):
class VerifiersEval(EvalBase):
"""
Verifiers Evaluation Environment.
Verifiers Evaluation using EvalBase pattern.
Evaluates models using Prime Intellect's Verifiers library.
Uses verifiers' native rollout and scoring machinery.
Uses verifiers' native batch evaluation for efficiency,
with EvalBase's standardized logging via evaluate_log().
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
name = "verifiers_evaluation"
env_config_cls = VerifiersEvaluationConfig # type: ignore[assignment]
def __init__(
self,
config: VerifiersEvaluationConfig,
server_configs: List[APIServerConfig],
slurm: bool = False,
testing: bool = False,
vf_env_name: str = "primeintellect/gsm8k",
env_args: dict = None,
temperature: float = 0.0,
max_tokens: int = 2048,
max_eval_items: int = -1,
max_concurrent: int = 64,
eval_dir: str = None,
verbose: bool = True,
**kwargs,
):
super().__init__(config, server_configs, slurm, testing)
self.config: VerifiersEvaluationConfig = config
self.vf_env_name = vf_env_name
self.env_args = env_args or {}
self.temperature = temperature
self.max_tokens = max_tokens
self.max_eval_items = max_eval_items
self.max_concurrent = max_concurrent
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
# Get reward function names for metrics reporting
# Load verifiers environment
self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
@classmethod
def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]:
"""Default configuration for evaluation."""
env_config = VerifiersEvaluationConfig(
vf_env_name="primeintellect/gsm8k",
)
server_configs = [
APIServerConfig(
model_name="gpt-4.1-nano",
base_url="https://api.openai.com/v1",
api_key=os.getenv("OPENAI_API_KEY"),
),
]
return env_config, server_configs
def _get_openai_client(self) -> AsyncOpenAI:
"""Create AsyncOpenAI client from first server config."""
server = self.server.servers[0]
config = server.config
return AsyncOpenAI(
api_key=config.api_key or "x",
base_url=config.base_url,
timeout=config.timeout,
# Initialize EvalBase (calls setup_data)
super().__init__(
eval_dir=eval_dir,
verbose=verbose,
**kwargs,
)
def _get_model_name(self) -> str:
"""Get model name from first server config."""
return self.server.servers[0].config.model_name
def get_generation_params(self):
"""Generation params for logging."""
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": 1,
}
async def setup(self) -> None:
"""Initialize the environment."""
num_eval = len(self.vf_env.get_eval_dataset())
if self.config.max_eval_items > 0:
num_eval = min(num_eval, self.config.max_eval_items)
def setup_data(self) -> list:
"""Return evaluation dataset from verifiers environment."""
dataset = self.vf_env.get_eval_dataset()
if self.max_eval_items > 0:
n = min(len(dataset), self.max_eval_items)
dataset = dataset.select(range(n))
return dataset.to_list()
print("\nVerifiers Evaluation Setup:")
print(f" Environment: {self.config.vf_env_name}")
print(f" Reward functions: {self.reward_func_names}")
print(f" Evaluation items: {num_eval}")
print(f" Max concurrent: {self.config.max_concurrent}")
async def evaluate(self) -> Dict:
"""Run evaluation using verifiers' native machinery."""
num_examples = (
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
async def run_item(
self, server: ServerManager, data_item: dict # noqa: ARG002
) -> Tuple[dict, list]:
"""Not used - we override __call__ for batch evaluation."""
raise NotImplementedError(
"VerifiersEval uses batch evaluation via __call__, not per-item run_item"
)
print(f"\n{'=' * 60}")
print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}")
print(f"{'=' * 60}")
print(f" Model: {self._get_model_name()}")
print(f" Temperature: {self.config.temperature}")
print(f" Max concurrent: {self.config.max_concurrent}")
print(f"{'=' * 60}\n")
async def __call__(self, server_manager: ServerManager):
"""Run evaluation using verifiers' native batch machinery."""
start_time = time.time()
# Create OpenAI client from atropos server config
client = self._get_openai_client()
model = self._get_model_name()
# Create OpenAI client from server config
server = server_manager.servers[0]
client = AsyncOpenAI(
api_key=server.config.api_key or "x",
base_url=server.config.base_url,
timeout=getattr(server.config, "timeout", 600),
)
model = server.config.model_name
# Let verifiers handle everything: rollouts + scoring
print(f"\n{'=' * 60}")
print(f"Verifiers Evaluation: {self.vf_env_name}")
print(f"{'=' * 60}")
print(f" Model: {model}")
print(f" Items: {len(self.data)}")
print(f" Reward functions: {self.reward_func_names}")
print(f" Temperature: {self.temperature}")
print(f" Max concurrent: {self.max_concurrent}")
print(f"{'=' * 60}\n")
num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
# Use verifiers' batch evaluation
results = await self.vf_env.evaluate(
client=client,
model=model,
sampling_args={
"temperature": self.config.temperature,
"max_tokens": self.config.max_token_length,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
},
num_examples=num_examples,
max_concurrent=self.config.max_concurrent,
max_concurrent=self.max_concurrent,
save_results=False,
)
end_time = time.time()
# Extract metrics from verifiers output
# Extract from verifiers output
rewards = results["reward"]
per_func_metrics = results["metrics"] # dict of func_name -> list[float]
per_func_metrics = results["metrics"]
prompts = results["prompt"]
completions = results["completion"]
answers = results["answer"]
@@ -232,14 +189,11 @@ class VerifiersEvaluationEnv(BaseEnv):
}
metrics = {
"avg_score": avg_score,
"accuracy": accuracy,
"total_evaluated": total,
"total_correct": correct,
"reward_breakdown": reward_breakdown,
"avg_score": avg_score,
}
# Print results
# Print results summary
print(f"\n{'=' * 60}")
print("Verifiers Evaluation Results")
print(f"{'=' * 60}")
@@ -253,14 +207,13 @@ class VerifiersEvaluationEnv(BaseEnv):
)
print(f"{'=' * 60}\n")
# Log to evaluate_log (atropos's logging system)
# Build samples for logging
system_prompt = self.vf_env.system_prompt or ""
samples = []
for i in range(min(total, 100)): # Limit samples for logging
prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
completion_msgs = completions[i] if completions[i] else []
# Build full message list
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -280,59 +233,114 @@
}
)
await self.evaluate_log(
metrics={"accuracy": accuracy, "avg_score": avg_score},
samples=samples,
# Use EvalBase's evaluate_log
task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
evaluate_log(
metrics=metrics,
eval_dir=getattr(self, "eval_dir", None),
task_name=task_name,
model_name=model,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": self.config.temperature,
"max_tokens": self.config.max_token_length,
},
generation_parameters=self.get_generation_params(),
samples=samples,
verbose=getattr(self, "verbose", False),
)
# Log to wandb
await self.wandb_log(metrics)
return metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
"""Log metrics to Weights & Biases."""
if not self.config.use_wandb or wandb_metrics is None:
return
# Lazy init if wandb not already initialized
if wandb.run is None:
wandb.init(
project="atropos-environments",
name=self.config.wandb_name,
config=self.config.model_dump(),
)
async def main():
"""CLI entry point for verifiers evaluation."""
import os
log_dict = {
"verifiers/accuracy": wandb_metrics.get("accuracy", 0),
"verifiers/avg_score": wandb_metrics.get("avg_score", 0),
"verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0),
"verifiers/total_correct": wandb_metrics.get("total_correct", 0),
}
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
# Add per-reward function metrics
reward_breakdown = wandb_metrics.get("reward_breakdown", {})
for func_name, data in reward_breakdown.items():
log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0)
log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0)
parser = argparse.ArgumentParser(
description="Evaluate models using Verifiers environments"
)
# Server args (same as eval_runner)
parser.add_argument(
"--server-url",
type=str,
default="http://localhost:8000/v1",
help="URL of the inference server",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name to evaluate",
)
parser.add_argument(
"--api-key",
type=str,
default=os.getenv("OPENAI_API_KEY", "x"),
help="API key (defaults to OPENAI_API_KEY env var)",
)
# Verifiers-specific args
parser.add_argument(
"--vf-env-name",
type=str,
default="primeintellect/gsm8k",
help="Verifiers environment name (e.g., primeintellect/gsm8k)",
)
parser.add_argument(
"--max-eval-items",
type=int,
default=-1,
help="Maximum items to evaluate (-1 for all)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Generation temperature",
)
parser.add_argument(
"--max-tokens",
type=int,
default=2048,
help="Maximum tokens per completion",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=64,
help="Maximum concurrent requests",
)
parser.add_argument(
"--eval-dir",
type=str,
default=None,
help="Directory to save evaluation results",
)
args = parser.parse_args()
wandb.log(log_dict)
# Create server manager
server_manager = ServerManager(
configs=[
APIServerConfig(
api_key=args.api_key,
base_url=args.server_url,
model_name=args.model_name,
health_check=False,
),
]
)
# Required abstract method implementations (stubs for evaluation-only mode)
async def get_next_item(self) -> Optional[Dict]:
"""Not used in evaluation mode."""
return None
async def collect_trajectories(self, item) -> Tuple[List, List]: # noqa: ARG002
"""Not used in evaluation mode."""
return [], []
# Create and run evaluation
eval_instance = VerifiersEval(
vf_env_name=args.vf_env_name,
max_eval_items=args.max_eval_items,
temperature=args.temperature,
max_tokens=args.max_tokens,
max_concurrent=args.max_concurrent,
eval_dir=args.eval_dir,
verbose=True,
)
return await eval_instance(server_manager)
if __name__ == "__main__":
VerifiersEvaluationEnv.cli()
asyncio.run(main())
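
For reference, a minimal sketch of driving the new EvalBase-style evaluator programmatically instead of through the CLI, mirroring main() in the diff above. The import path verifiers_eval, the local vLLM URL, the model name, and the eval_dir value are placeholders (assumptions, not part of this commit); the constructors and keyword arguments are the ones shown in the new code.

import asyncio
import os

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager

# Assumed import path: the module changed in this commit
from verifiers_eval import VerifiersEval


async def run():
    # Point at any OpenAI-compatible endpoint (vLLM, SGLang, Ollama, OpenAI, ...)
    server_manager = ServerManager(
        configs=[
            APIServerConfig(
                api_key=os.getenv("OPENAI_API_KEY", "x"),
                base_url="http://localhost:8000/v1",  # placeholder local server
                model_name="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
                health_check=False,
            ),
        ]
    )

    evaluator = VerifiersEval(
        vf_env_name="primeintellect/gsm8k",
        max_eval_items=100,  # -1 evaluates the full dataset
        temperature=0.0,
        max_tokens=2048,
        max_concurrent=64,
        eval_dir="./eval_results",  # assumed output directory for evaluate_log
        verbose=True,
    )

    # __call__ runs the batch evaluation and returns the metrics dict
    metrics = await evaluator(server_manager)
    print(metrics["avg_score"], metrics["accuracy"])


if __name__ == "__main__":
    asyncio.run(run())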