clean up eval, pin verifiers version

balyan.sid@gmail.com 2026-01-12 05:38:15 +05:30
parent d98bc6d9fc
commit 24b4488c60
3 changed files with 202 additions and 244 deletions


@@ -4,6 +4,9 @@ Verifiers Evaluation Environment for Atropos
This environment evaluates models using Prime Intellect's Verifiers library.
It supports any environment registered with the Verifiers ecosystem.
Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible
client and verifiers handles generation, parsing, and scoring.
To install a Verifiers/Prime environment:
1. uv tool install prime
2. prime login
@@ -15,18 +18,24 @@ Usage:
--env.vf_env_name primeintellect/gsm8k \
--openai.model_name gpt-4.1-nano \
--openai.api_key $OPENAI_API_KEY
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.):
python verifiers_eval.py evaluate \
--env.vf_env_name primeintellect/gsm8k \
--openai.model_name Qwen/Qwen2.5-7B-Instruct \
--openai.base_url http://localhost:8000/v1
"""
import asyncio
import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import verifiers as vf
import wandb
from openai import AsyncOpenAI
from pydantic import Field
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
@@ -36,14 +45,12 @@ from atroposlib.envs.base import (
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(timeout_seconds: int):
def _no_signal_timeout(_timeout_seconds: int):
"""Replacement timeout decorator that doesn't use signals."""
def decorator(func):
def wrapper(*args, **kwargs):
# Just call the function without timeout
# This is safe because we're in an async context with our own timeouts
# timeout_seconds is intentionally unused - we're replacing the timeout logic
# Just call the function without timeout - safe in async context
return func(*args, **kwargs)
return wrapper
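
The patch above deliberately drops the timeout. If a real timeout were ever wanted without signals, one option is a thread-based variant; a minimal sketch (illustrative only, not part of this change):

import concurrent.futures

_timeout_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)

def _thread_timeout(timeout_seconds: int):
    """Signal-free timeout: run the wrapped call in a worker thread."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            future = _timeout_pool.submit(func, *args, **kwargs)
            # Raises concurrent.futures.TimeoutError if the call overruns;
            # the worker thread keeps running in the background (known tradeoff).
            return future.result(timeout=timeout_seconds)
        return wrapper
    return decorator
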
@@ -67,41 +74,47 @@ except ImportError:
class VerifiersEvaluationConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
# Verifiers environment
vf_env_name: str = Field(
default="",
description="Verifiers environment name (e.g., primeintellect/gsm8k)",
)
env_args: dict = Field(
env_args: Dict[str, Any] = Field(
default_factory=dict,
description="Additional arguments for verifiers environment",
)
temperature: float = Field(
default=0.0, description="Temperature for generation (0.0 for deterministic)"
default=0.0,
description="Temperature for generation (0.0 for deterministic)",
)
max_eval_items: int = Field(
default=-1,
description="Maximum number of items to evaluate (-1 for all)",
)
max_concurrent: int = Field(
default=64,
description="Maximum concurrent requests to the model",
)
max_retries: int = Field(
default=3, description="Maximum retries for failed API calls"
)
retry_delay: float = Field(
default=1.0, description="Delay between retries in seconds"
)
min_response_length: int = Field(
default=1, description="Minimum response length to consider valid"
)
full_debug: bool = Field(default=False, description="Enable full debug output")
max_eval_items: int = Field(
default=-1, description="Maximum number of items to evaluate (-1 for all)"
)
# Override BaseEnvConfig defaults for evaluation
group_size: int = 1
max_num_workers: int = 1024
max_eval_workers: int = 256
max_num_workers_per_node: int = 128
use_wandb: bool = True
rollout_server_url: str = "http://localhost:8000"
total_steps: int = 1
steps_per_eval: int = 1
wandb_name: str = "verifiers_eval"
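
For illustration, the same fields can also be set programmatically (values are placeholders; this assumes the inherited BaseEnvConfig fields all have usable defaults):

config = VerifiersEvaluationConfig(
    vf_env_name="primeintellect/gsm8k",
    env_args={},         # extra kwargs forwarded to vf.load_environment
    temperature=0.0,     # deterministic generation
    max_eval_items=100,  # cap the eval set; -1 evaluates everything
    max_concurrent=64,   # concurrent requests to the model server
)
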
class VerifiersEvaluationEnv(BaseEnv):
"""
Verifiers Evaluation Environment.
Evaluates models using Prime Intellect's Verifiers library rubrics.
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, etc.)
Evaluates models using Prime Intellect's Verifiers library.
Uses verifiers' native rollout and scoring machinery.
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
name = "verifiers_evaluation"
@@ -117,39 +130,10 @@ class VerifiersEvaluationEnv(BaseEnv):
super().__init__(config, server_configs, slurm, testing)
self.config: VerifiersEvaluationConfig = config
# Load verifiers environment
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
self.rubric = self.vf_env.rubric
# Extract rubric components from RubricGroup
# RubricGroup.funcs is empty - need to collect from individual rubrics
self.parser = self.rubric.parser
self.reward_funcs: List[Callable] = []
self.reward_weights: List[float] = []
self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func
if hasattr(self.rubric, "rubrics"):
# RubricGroup: collect from all individual rubrics
for rubric in self.rubric.rubrics:
class_objects = getattr(rubric, "class_objects", {})
for func, weight in zip(rubric.funcs, rubric.weights):
self.reward_funcs.append(func)
self.reward_weights.append(weight)
self.rubric_class_objects.append(class_objects)
else:
# Single Rubric
self.reward_funcs = self.rubric.funcs
self.reward_weights = self.rubric.weights
class_objects = getattr(self.rubric, "class_objects", {})
self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)
total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
self.reward_scales = [weight / total_weight for weight in self.reward_weights]
self.system_prompt = self.vf_env.system_prompt
# Tracking
self.eval_items: List[Dict] = []
self._dataset_loaded = False
# Get reward function names for metrics reporting
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
@classmethod
def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]:
@@ -166,183 +150,87 @@ class VerifiersEvaluationEnv(BaseEnv):
]
return env_config, server_configs
def _get_openai_client(self) -> AsyncOpenAI:
"""Create AsyncOpenAI client from first server config."""
server = self.server.servers[0]
config = server.config
return AsyncOpenAI(
api_key=config.api_key or "x",
base_url=config.base_url,
timeout=config.timeout,
)
def _get_model_name(self) -> str:
"""Get model name from first server config."""
return self.server.servers[0].config.model_name
async def setup(self) -> None:
"""Initialize the environment and load datasets."""
if not self._dataset_loaded:
# Load datasets from verifiers environment
test_data = self.vf_env.get_eval_dataset()
self.eval_items = test_data.select_columns(["question", "answer"]).to_list()
# Limit items if max_eval_items is set
if self.config.max_eval_items > 0:
self.eval_items = self.eval_items[: self.config.max_eval_items]
self._dataset_loaded = True
"""Initialize the environment."""
num_eval = len(self.vf_env.get_eval_dataset())
if self.config.max_eval_items > 0:
num_eval = min(num_eval, self.config.max_eval_items)
print("\nVerifiers Evaluation Setup:")
print(f" Environment: {self.config.vf_env_name}")
print(f" Reward functions: {len(self.reward_funcs)}")
print(f" Reward weights: {self.reward_weights}")
print(f" Loaded {len(self.eval_items)} evaluation items")
print(f" Reward functions: {self.reward_func_names}")
print(f" Evaluation items: {num_eval}")
print(f" Max concurrent: {self.config.max_concurrent}")
async def rollout_and_score(self, item: Dict) -> Optional[Dict]:
"""
Run evaluation on a single item and return the result.
async def evaluate(self) -> Dict:
"""Run evaluation using verifiers' native machinery."""
num_examples = (
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
)
Args:
item: Dict with 'question' and 'answer' keys
Returns:
Dict with evaluation results or None if failed
"""
question = item["question"]
answer = item["answer"]
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": question},
]
# Build API call parameters
kwargs = {
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_token_length,
"n": 1,
}
response_text = ""
for attempt in range(self.config.max_retries):
try:
# Direct API call (no ManagedServer) - eval doesn't need token tracking
response = await self.server.chat_completion(**kwargs)
response_text = response.choices[0].message.content or ""
if len(response_text) >= self.config.min_response_length:
break
except Exception as e:
if self.config.full_debug:
print(f" API error (attempt {attempt + 1}): {e}")
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
continue
if not response_text:
return None
# Build completion messages for scoring
completion_messages = messages + [
{"role": "assistant", "content": response_text}
]
# Parse answer
answer_parsed = self.parser.parse_answer(completion=response_text)
# Score using reward funcs (async functions need await)
# Use signature introspection to pass only required params (like verifiers does)
rewards = []
for i, func in enumerate(self.reward_funcs):
try:
# Build merged dict of all possible parameters
class_objects = self.rubric_class_objects[i]
merged = {
"completion": completion_messages,
"answer": answer,
"prompt": question,
}
merged.update(class_objects) # Adds parser, etc.
# Filter to only params the function accepts
sig = inspect.signature(func)
if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
# Function accepts **kwargs, pass everything
kwargs = merged
else:
# Only pass params in signature
kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
result = func(**kwargs)
# Reward functions may be async coroutines
if asyncio.iscoroutine(result):
reward = await result
else:
reward = result
reward = float(reward)
except Exception as e:
if self.config.full_debug:
print(f" Reward func {func.__name__} error: {e}")
reward = 0.0
rewards.append(reward)
weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
score = sum(weighted_rewards)
if self.config.full_debug:
print("\n--- Item ---")
print(f"Question: {question[:100]}...")
print(f"Gold answer: {answer}")
print(f"Model parsed: {answer_parsed}")
print(f"Rewards: {rewards}")
print(f"Score: {score}")
return {
"question": question,
"gold_answer": answer,
"response": response_text,
"model_parsed": str(answer_parsed) if answer_parsed else None,
"rewards": rewards,
"weighted_rewards": weighted_rewards,
"score": score,
"correct": bool(score > 0),
}
async def evaluate(self, *args, **kwargs) -> Dict:
"""Run the full evaluation."""
print(f"\n{'=' * 60}")
print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}")
print(f"{'=' * 60}")
print(f" Total questions: {len(self.eval_items)}")
print(f" Model: {self._get_model_name()}")
print(f" Temperature: {self.config.temperature}")
print(f" Max concurrent: {self.config.max_concurrent}")
print(f"{'=' * 60}\n")
start_time = time.time()
# Run sequentially to avoid signal/threading issues with math_verify parser
# The parser uses signals for timeouts which only work in main thread
from tqdm import tqdm
# Create OpenAI client from atropos server config
client = self._get_openai_client()
model = self._get_model_name()
results = []
for item in tqdm(self.eval_items, desc="Evaluating"):
result = await self.rollout_and_score(item)
results.append(result)
# Filter out failed results
valid_results = [r for r in results if r is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return {"error": "No valid results", "accuracy": 0.0}
# Let verifiers handle everything: rollouts + scoring
results = await self.vf_env.evaluate(
client=client,
model=model,
sampling_args={
"temperature": self.config.temperature,
"max_tokens": self.config.max_token_length,
},
num_examples=num_examples,
max_concurrent=self.config.max_concurrent,
save_results=False,
)
end_time = time.time()
# Calculate metrics
total = len(valid_results)
scores = [r["score"] for r in valid_results]
correct = sum(1 for r in valid_results if r["correct"])
# Extract metrics from verifiers output
rewards = results["reward"]
per_func_metrics = results["metrics"] # dict of func_name -> list[float]
prompts = results["prompt"]
completions = results["completion"]
answers = results["answer"]
avg_score = sum(scores) / total if total > 0 else 0.0
total = len(rewards)
correct = sum(1 for r in rewards if r > 0)
avg_score = sum(rewards) / total if total > 0 else 0.0
accuracy = correct / total if total > 0 else 0.0
# Per-reward function breakdown
reward_breakdown = {}
for i, weight in enumerate(self.reward_weights):
func_rewards = [r["rewards"][i] for r in valid_results]
reward_breakdown[f"reward_func_{i}"] = {
"weight": weight,
"avg": sum(func_rewards) / len(func_rewards),
"correct": sum(1 for r in func_rewards if r > 0),
}
for func_name, values in per_func_metrics.items():
if values:
reward_breakdown[func_name] = {
"avg": sum(values) / len(values),
"correct": sum(1 for v in values if v > 0),
}
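
As a toy illustration of the shapes this block assumes (field names follow the accesses above; the actual verifiers output object may differ), with hypothetical values only:

rewards = [1.0, 0.0, 0.5]            # results["reward"]
per_func = {
    "exact_match": [1.0, 0.0, 0.0],  # results["metrics"]
    "format_reward": [1.0, 0.0, 1.0],
}
total = len(rewards)                                  # 3
accuracy = sum(1 for r in rewards if r > 0) / total   # 2/3
avg_score = sum(rewards) / total                      # 0.5
breakdown = {
    name: {"avg": sum(v) / len(v), "correct": sum(1 for x in v if x > 0)}
    for name, v in per_func.items()
}
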
metrics = {
"avg_score": avg_score,
@@ -366,22 +254,32 @@ class VerifiersEvaluationEnv(BaseEnv):
)
print(f"{'=' * 60}\n")
# Log to evaluate_log
samples = [
{
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": r["question"]},
{"role": "assistant", "content": r["response"]},
],
"question": r["question"],
"gold_answer": r["gold_answer"],
"model_parsed": r["model_parsed"],
"score": r["score"],
"correct": r["correct"],
}
for r in valid_results
]
# Log to evaluate_log (atropos's logging system)
system_prompt = self.vf_env.system_prompt or ""
samples = []
for i in range(min(total, 100)): # Limit samples for logging
prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
completion_msgs = completions[i] if completions[i] else []
# Build full message list
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(prompt_msgs)
if isinstance(completion_msgs, list):
messages.extend(completion_msgs)
samples.append(
{
"messages": messages,
"gold_answer": answers[i] if i < len(answers) else "",
"score": rewards[i],
"correct": rewards[i] > 0,
"metrics": {
k: v[i] for k, v in per_func_metrics.items() if i < len(v)
},
}
)
await self.evaluate_log(
metrics={"accuracy": accuracy, "avg_score": avg_score},
@@ -430,13 +328,11 @@ class VerifiersEvaluationEnv(BaseEnv):
# Required abstract method implementations (stubs for evaluation-only mode)
async def get_next_item(self) -> Optional[Dict]:
"""Not used in evaluation mode."""
raise NotImplementedError("get_next_item not supported in evaluation-only mode")
return None
async def collect_trajectories(self, item) -> Tuple[List, List]:
async def collect_trajectories(self, item) -> Tuple[List, List]: # noqa: ARG002
"""Not used in evaluation mode."""
raise NotImplementedError(
"collect_trajectories not supported in evaluation-only mode"
)
return [], []
if __name__ == "__main__":