"""
|
|
Verifiers Evaluation Environment for Atropos
|
|
|
|
This environment evaluates models using Prime Intellect's Verifiers library.
|
|
It supports any environment registered with the Verifiers ecosystem.
|
|
|
|
Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible
|
|
client and verifiers handles generation, parsing, and scoring.
|
|
|
|
To install a Verifiers/Prime environment:
|
|
1. uv tool install prime
|
|
2. prime login
|
|
3. prime env install will/wordle (or any owner/environment)
|
|
Docs: https://docs.primeintellect.ai/tutorials-environments/install
|
|
|
|
Usage:
|
|
# Evaluate with local server
|
|
python verifiers_eval.py evaluate \
|
|
--env.vf_env_name "primeintellect/gsm8k" \
|
|
--env.max_eval_items 100 \
|
|
--openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
|
|
--openai.base_url "http://localhost:8000/v1"
|
|
|
|
# Evaluate with OpenAI
|
|
python verifiers_eval.py evaluate \
|
|
--env.vf_env_name "primeintellect/gsm8k" \
|
|
--env.max_eval_items 50 \
|
|
--openai.model_name "gpt-4o" \
|
|
--openai.api_key "$OPENAI_API_KEY" \
|
|
--openai.base_url "https://api.openai.com/v1"

Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""

import json
import time
from typing import Any, Dict, List, Tuple

import verifiers as vf

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)

# Import ManagedServerAdapter from shared location
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter

# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)


def _no_signal_timeout(
    _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
):
    """Replacement timeout decorator that doesn't use signals.

    Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
    """
    # Silence unused parameter warnings - these match the original API signature
    del _timeout_seconds, timeout_seconds

    def decorator(func):
        def wrapper(*args, **kwargs):
            # Just call the function without timeout - safe in async context
            return func(*args, **kwargs)

        return wrapper

    return decorator
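

# For context: math_verify's timeout decorator is signal-based, and signal
# handlers can only be installed in the main thread (signal.signal raises
# ValueError elsewhere), so a worker thread running the asyncio event loop
# would crash when grading. The no-op decorator above trades timeout
# protection for thread safety.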
try:
    import math_verify.grader
    import math_verify.parser
    import math_verify.utils

    # Patch all modules that use the timeout decorator
    math_verify.utils.timeout = _no_signal_timeout
    math_verify.parser.timeout = _no_signal_timeout
    math_verify.grader.timeout = _no_signal_timeout
except ImportError:
    pass  # math_verify not installed


class VfEvalConfig(BaseEnvConfig):
    """Configuration for Verifiers evaluation environment."""

    vf_env_name: str = "primeintellect/gsm8k"
    env_args: str = "{}"  # JSON string for environment-specific args
    eval_temperature: float = 0.0
    eval_max_tokens: int = 2048
    max_eval_items: int = -1  # -1 means evaluate all items
    max_concurrent: int = 64

    # Override BaseEnvConfig defaults for eval mode
    group_size: int = 1
    total_steps: int = 1
    steps_per_eval: int = 1
    use_wandb: bool = True

    def get_env_args(self) -> Dict[str, Any]:
        """Parse env_args JSON string into dict."""
        if isinstance(self.env_args, dict):
            return self.env_args
        return json.loads(self.env_args)
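
# A quick sketch of how env_args flows through (the key is hypothetical -
# accepted keys depend on the loaded verifiers environment):
#     cfg = VfEvalConfig(env_args='{"use_few_shot": true}')
#     cfg.get_env_args()  # -> {"use_few_shot": True}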


class VerifiersEvalEnv(BaseEnv):
    """
    Verifiers Evaluation Environment using BaseEnv pattern.

    Uses verifiers' native batch evaluation for efficiency,
    with BaseEnv's standardized logging via evaluate_log().

    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
    """

    name = "verifiers_eval"
    env_config_cls = VfEvalConfig  # type: ignore[assignment]

    @classmethod
    def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
        """Return default configurations."""
        env_config = VfEvalConfig(
            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
            vf_env_name="primeintellect/gsm8k",
            wandb_name="verifiers_eval",
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2.5-1.5B-Instruct",
                base_url="http://localhost:9001/v1",
                api_key="x",
            ),
        ]
        return env_config, server_configs

    async def setup(self):
        """Load verifiers environment and dataset."""
        env_args = self.config.get_env_args()
        self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
        # NOTE: _get_reward_func_names() is a private verifiers API and may
        # change between releases.
        self.reward_func_names = self.vf_env.rubric._get_reward_func_names()

        # Load evaluation dataset
        dataset = self.vf_env.get_eval_dataset()
        if self.config.max_eval_items > 0:
            n = min(len(dataset), self.config.max_eval_items)
            dataset = dataset.select(range(n))
        self.data = dataset.to_list()

    async def get_next_item(self):
        """Not used in eval mode - stub implementation."""
        return None

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Not used in eval mode - stub implementation."""
        _ = item  # unused in eval mode
        return (
            ScoredDataGroup(
                tokens=[],
                masks=[],
                scores=[],
                messages=[],
                inference_logprobs=[],
                advantages=[],
                ref_logprobs=[],
                generation_params=None,
                group_overrides=None,
                overrides=[],
                images=[],
            ),
            [],
        )

    async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
        """Run evaluation using verifiers with ManagedServer."""
        start_time = time.time()

        # Get server config
        if hasattr(self.server, "servers") and self.server.servers:
            server_config = self.server.servers[0].config
        else:
            server_config = self.server_configs[0]

        model_name = server_config.model_name

        print(f"\n{'=' * 60}")
        print(f"Verifiers Evaluation: {self.config.vf_env_name}")
        print(f"{'=' * 60}")
        print(f"  Model: {model_name}")
        print(f"  Items: {len(self.data)}")
        print(f"  Reward functions: {self.reward_func_names}")
        print(f"  Temperature: {self.config.eval_temperature}")
        print(f"  Max concurrent: {self.config.max_concurrent}")
        print(f"{'=' * 60}\n")

        num_examples = (
            self.config.max_eval_items if self.config.max_eval_items > 0 else -1
        )

        # Use ManagedServer for automatic token/logprob tracking
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Create adapter that looks like AsyncOpenAI for verifiers
            adapter = ManagedServerAdapter(
                managed_server=managed,
                base_url=server_config.base_url,
            )

            # Use verifiers' batch evaluation
            results = await self.vf_env.evaluate(
                client=adapter,
                model=model_name,
                sampling_args={
                    "temperature": self.config.eval_temperature,
                    "max_tokens": self.config.eval_max_tokens,
                },
                num_examples=num_examples,
                max_concurrent=self.config.max_concurrent,
                save_results=False,
            )

        end_time = time.time()

        # Extract from verifiers output
        rewards = results["reward"]
        per_func_metrics = results["metrics"]
        prompts = results["prompt"]
        completions = results["completion"]
        answers = results["answer"]

        total = len(rewards)
        correct = sum(1 for r in rewards if r > 0)
        avg_score = sum(rewards) / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        # Per-reward function breakdown
        reward_breakdown = {}
        for func_name, values in per_func_metrics.items():
            if values:
                reward_breakdown[func_name] = {
                    "avg": sum(values) / len(values),
                    "correct": sum(1 for v in values if v > 0),
                }

        metrics = {
            "accuracy": accuracy,
            "avg_score": avg_score,
        }

        # Add per-function metrics
        for func_name, data in reward_breakdown.items():
            metrics[f"{func_name}_avg"] = data["avg"]
            metrics[f"{func_name}_correct_rate"] = data["correct"] / total

        # Print results summary
        print(f"\n{'=' * 60}")
        print("Verifiers Evaluation Results")
        print(f"{'=' * 60}")
        print(f"  Average Score: {avg_score:.4f}")
        print(f"  Accuracy: {accuracy:.2%} ({correct}/{total})")
        print(f"  Time: {end_time - start_time:.1f}s")
        print("\n  Per-Reward Function:")
        for name, data in reward_breakdown.items():
            print(
                f"    {name}: avg={data['avg']:.4f}, correct={data['correct']}/{total}"
            )
        print(f"{'=' * 60}\n")

        # Build samples for logging
        system_prompt = self.vf_env.system_prompt or ""
        samples = []
        for i in range(min(total, 100)):  # Limit samples for logging
            prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
            completion_msgs = completions[i] if completions[i] else []

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.extend(prompt_msgs)
            if isinstance(completion_msgs, list):
                messages.extend(completion_msgs)

            samples.append(
                {
                    "messages": messages,
                    "gold_answer": answers[i] if i < len(answers) else "",
                    "score": rewards[i],
                    "correct": rewards[i] > 0,
                    "metrics": {
                        k: v[i] for k, v in per_func_metrics.items() if i < len(v)
                    },
                }
            )

        # Use BaseEnv's evaluate_log
        task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
        await self.evaluate_log(
            metrics=metrics,
            task_name=task_name,
            model_name=model_name,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": self.config.eval_temperature,
                "max_tokens": self.config.eval_max_tokens,
                "n": 1,
            },
            samples=samples,
            verbose=True,
        )

        return metrics


if __name__ == "__main__":
    VerifiersEvalEnv.cli()