Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-26 17:13:09 +00:00
clean up eval, pin verifiers version

parent d98bc6d9fc
commit 24b4488c60

3 changed files with 202 additions and 244 deletions
@@ -4,6 +4,9 @@ Verifiers Evaluation Environment for Atropos

This environment evaluates models using Prime Intellect's Verifiers library.
It supports any environment registered with the Verifiers ecosystem.

Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible
client and verifiers handles generation, parsing, and scoring.

To install a Verifiers/Prime environment:
1. uv tool install prime
2. prime login
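The docstring above describes the new flow: hand verifiers an OpenAI-compatible client and let it drive generation, parsing, and scoring. A minimal standalone sketch of that flow outside Atropos, mirroring the vf.load_environment and vf_env.evaluate calls that appear later in this diff (endpoint, model, and example counts are illustrative, not part of the commit):

import asyncio

import verifiers as vf
from openai import AsyncOpenAI


async def main() -> None:
    # Load a registered Verifiers environment (installed via the prime CLI).
    env = vf.load_environment("primeintellect/gsm8k")
    # Any OpenAI-compatible endpoint works (OpenAI, vLLM, SGLang, Ollama, ...).
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="x")
    # verifiers drives rollouts, parsing, and scoring; we supply client + model.
    results = await env.evaluate(
        client=client,
        model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model name
        sampling_args={"temperature": 0.0, "max_tokens": 2048},
        num_examples=8,  # small smoke-test slice
        max_concurrent=8,
        save_results=False,
    )
    print(results["reward"])  # one reward per evaluated example


asyncio.run(main())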
@@ -15,18 +18,24 @@ Usage:
        --env.vf_env_name primeintellect/gsm8k \
        --openai.model_name gpt-4.1-nano \
        --openai.api_key $OPENAI_API_KEY

Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.):
    python verifiers_eval.py evaluate \
        --env.vf_env_name primeintellect/gsm8k \
        --openai.model_name Qwen/Qwen2.5-7B-Instruct \
        --openai.base_url http://localhost:8000/v1
"""

import asyncio
import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import verifiers as vf
import wandb
from openai import AsyncOpenAI
from pydantic import Field

import wandb

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
@@ -36,14 +45,12 @@ from atroposlib.envs.base import (

# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(timeout_seconds: int):
def _no_signal_timeout(_timeout_seconds: int):
    """Replacement timeout decorator that doesn't use signals."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            # Just call the function without timeout
            # This is safe because we're in an async context with our own timeouts
            # timeout_seconds is intentionally unused - we're replacing the timeout logic
            # Just call the function without timeout - safe in async context
            return func(*args, **kwargs)

        return wrapper
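The comment above is the motivation for this patch: CPython only allows signal handlers to be installed from the main thread, so a signal-based timeout breaks once scoring runs on an asyncio worker thread. A standalone, POSIX-only demonstration of the failure mode (not part of the commit):

import signal
import threading


def install_alarm_handler() -> None:
    try:
        # signal.signal() may only be called from the main thread.
        signal.signal(signal.SIGALRM, lambda signum, frame: None)
    except ValueError as exc:
        # In a worker thread this raises:
        # "signal only works in main thread of the main interpreter"
        print(f"cannot install handler: {exc}")


threading.Thread(target=install_alarm_handler).start()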
@@ -67,41 +74,47 @@ except ImportError:

class VerifiersEvaluationConfig(BaseEnvConfig):
    """Configuration for Verifiers evaluation environment."""

    # Verifiers environment
    vf_env_name: str = Field(
        default="",
        description="Verifiers environment name (e.g., primeintellect/gsm8k)",
    )
    env_args: dict = Field(
    env_args: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional arguments for verifiers environment",
    )

    temperature: float = Field(
        default=0.0, description="Temperature for generation (0.0 for deterministic)"
        default=0.0,
        description="Temperature for generation (0.0 for deterministic)",
    )
    max_eval_items: int = Field(
        default=-1,
        description="Maximum number of items to evaluate (-1 for all)",
    )
    max_concurrent: int = Field(
        default=64,
        description="Maximum concurrent requests to the model",
    )

    max_retries: int = Field(
        default=3, description="Maximum retries for failed API calls"
    )
    retry_delay: float = Field(
        default=1.0, description="Delay between retries in seconds"
    )
    min_response_length: int = Field(
        default=1, description="Minimum response length to consider valid"
    )
    full_debug: bool = Field(default=False, description="Enable full debug output")
    max_eval_items: int = Field(
        default=-1, description="Maximum number of items to evaluate (-1 for all)"
    )
    # Override BaseEnvConfig defaults for evaluation
    group_size: int = 1
    max_num_workers: int = 1024
    max_eval_workers: int = 256
    max_num_workers_per_node: int = 128
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    total_steps: int = 1
    steps_per_eval: int = 1
    wandb_name: str = "verifiers_eval"


class VerifiersEvaluationEnv(BaseEnv):
    """
    Verifiers Evaluation Environment.

    Evaluates models using Prime Intellect's Verifiers library rubrics.
    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, etc.)
    Evaluates models using Prime Intellect's Verifiers library.
    Uses verifiers' native rollout and scoring machinery.

    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
    """

    name = "verifiers_evaluation"
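The config class above uses the pydantic Field pattern throughout; note the change from a bare dict annotation to Dict[str, Any] with default_factory. A minimal runnable sketch of that pattern, using a stand-in BaseModel rather than the real atroposlib BaseEnvConfig (field shapes mirror the diff, the class name is hypothetical):

from typing import Any, Dict

from pydantic import BaseModel, Field


class EvalConfigSketch(BaseModel):
    # Stand-in for BaseEnvConfig; fields mirror the class above.
    vf_env_name: str = Field(default="", description="Verifiers environment name")
    env_args: Dict[str, Any] = Field(
        default_factory=dict,  # avoids a shared mutable default across instances
        description="Additional arguments for verifiers environment",
    )
    temperature: float = 0.0
    max_eval_items: int = -1


cfg = EvalConfigSketch(vf_env_name="primeintellect/gsm8k", max_eval_items=100)
print(cfg.env_args)  # {} - each instance gets its own dict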
@@ -117,39 +130,10 @@ class VerifiersEvaluationEnv(BaseEnv):
        super().__init__(config, server_configs, slurm, testing)
        self.config: VerifiersEvaluationConfig = config

        # Load verifiers environment
        self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
        self.rubric = self.vf_env.rubric

        # Extract rubric components from RubricGroup
        # RubricGroup.funcs is empty - need to collect from individual rubrics
        self.parser = self.rubric.parser
        self.reward_funcs: List[Callable] = []
        self.reward_weights: List[float] = []
        self.rubric_class_objects: List[Dict[str, Any]] = []  # class_objects per func

        if hasattr(self.rubric, "rubrics"):
            # RubricGroup: collect from all individual rubrics
            for rubric in self.rubric.rubrics:
                class_objects = getattr(rubric, "class_objects", {})
                for func, weight in zip(rubric.funcs, rubric.weights):
                    self.reward_funcs.append(func)
                    self.reward_weights.append(weight)
                    self.rubric_class_objects.append(class_objects)
        else:
            # Single Rubric
            self.reward_funcs = self.rubric.funcs
            self.reward_weights = self.rubric.weights
            class_objects = getattr(self.rubric, "class_objects", {})
            self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)

        total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
        self.reward_scales = [weight / total_weight for weight in self.reward_weights]
        self.system_prompt = self.vf_env.system_prompt

        # Tracking
        self.eval_items: List[Dict] = []
        self._dataset_loaded = False
        # Get reward function names for metrics reporting
        self.reward_func_names = self.vf_env.rubric._get_reward_func_names()

    @classmethod
    def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]:
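The weight-normalization logic deleted above (total_weight and reward_scales) reduces to simple arithmetic. A standalone sketch with illustrative numbers, not taken from the commit:

# Illustrative rubric weights (not from the commit).
reward_weights = [1.0, 0.5, 0.5]
total_weight = sum(reward_weights) if reward_weights else 1.0
reward_scales = [w / total_weight for w in reward_weights]  # [0.5, 0.25, 0.25]

# Per-function rewards for one item, combined into a single weighted score.
rewards = [1.0, 0.0, 1.0]
score = sum(r * s for r, s in zip(rewards, reward_scales))  # 0.75
print(reward_scales, score)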
@@ -166,183 +150,87 @@ class VerifiersEvaluationEnv(BaseEnv):
        ]
        return env_config, server_configs

    def _get_openai_client(self) -> AsyncOpenAI:
        """Create AsyncOpenAI client from first server config."""
        server = self.server.servers[0]
        config = server.config
        return AsyncOpenAI(
            api_key=config.api_key or "x",
            base_url=config.base_url,
            timeout=config.timeout,
        )

    def _get_model_name(self) -> str:
        """Get model name from first server config."""
        return self.server.servers[0].config.model_name

    async def setup(self) -> None:
        """Initialize the environment and load datasets."""
        if not self._dataset_loaded:
            # Load datasets from verifiers environment
            test_data = self.vf_env.get_eval_dataset()
            self.eval_items = test_data.select_columns(["question", "answer"]).to_list()

            # Limit items if max_eval_items is set
            if self.config.max_eval_items > 0:
                self.eval_items = self.eval_items[: self.config.max_eval_items]

            self._dataset_loaded = True
        """Initialize the environment."""
        num_eval = len(self.vf_env.get_eval_dataset())
        if self.config.max_eval_items > 0:
            num_eval = min(num_eval, self.config.max_eval_items)

        print("\nVerifiers Evaluation Setup:")
        print(f" Environment: {self.config.vf_env_name}")
        print(f" Reward functions: {len(self.reward_funcs)}")
        print(f" Reward weights: {self.reward_weights}")
        print(f" Loaded {len(self.eval_items)} evaluation items")
        print(f" Reward functions: {self.reward_func_names}")
        print(f" Evaluation items: {num_eval}")
        print(f" Max concurrent: {self.config.max_concurrent}")

    async def rollout_and_score(self, item: Dict) -> Optional[Dict]:
        """
        Run evaluation on a single item and return the result.
    async def evaluate(self) -> Dict:
        """Run evaluation using verifiers' native machinery."""
        num_examples = (
            self.config.max_eval_items if self.config.max_eval_items > 0 else -1
        )

        Args:
            item: Dict with 'question' and 'answer' keys

        Returns:
            Dict with evaluation results or None if failed
        """
        question = item["question"]
        answer = item["answer"]

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": question},
        ]

        # Build API call parameters
        kwargs = {
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_token_length,
            "n": 1,
        }

        response_text = ""
        for attempt in range(self.config.max_retries):
            try:
                # Direct API call (no ManagedServer) - eval doesn't need token tracking
                response = await self.server.chat_completion(**kwargs)
                response_text = response.choices[0].message.content or ""

                if len(response_text) >= self.config.min_response_length:
                    break

            except Exception as e:
                if self.config.full_debug:
                    print(f" API error (attempt {attempt + 1}): {e}")
                if attempt < self.config.max_retries - 1:
                    await asyncio.sleep(self.config.retry_delay)
                    continue

        if not response_text:
            return None

        # Build completion messages for scoring
        completion_messages = messages + [
            {"role": "assistant", "content": response_text}
        ]

        # Parse answer
        answer_parsed = self.parser.parse_answer(completion=response_text)

        # Score using reward funcs (async functions need await)
        # Use signature introspection to pass only required params (like verifiers does)
        rewards = []
        for i, func in enumerate(self.reward_funcs):
            try:
                # Build merged dict of all possible parameters
                class_objects = self.rubric_class_objects[i]
                merged = {
                    "completion": completion_messages,
                    "answer": answer,
                    "prompt": question,
                }
                merged.update(class_objects)  # Adds parser, etc.

                # Filter to only params the function accepts
                sig = inspect.signature(func)
                if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
                    # Function accepts **kwargs, pass everything
                    kwargs = merged
                else:
                    # Only pass params in signature
                    kwargs = {k: v for k, v in merged.items() if k in sig.parameters}

                result = func(**kwargs)
                # Reward functions may be async coroutines
                if asyncio.iscoroutine(result):
                    reward = await result
                else:
                    reward = result
                reward = float(reward)
            except Exception as e:
                if self.config.full_debug:
                    print(f" Reward func {func.__name__} error: {e}")
                reward = 0.0
            rewards.append(reward)

        weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
        score = sum(weighted_rewards)

        if self.config.full_debug:
            print("\n--- Item ---")
            print(f"Question: {question[:100]}...")
            print(f"Gold answer: {answer}")
            print(f"Model parsed: {answer_parsed}")
            print(f"Rewards: {rewards}")
            print(f"Score: {score}")

        return {
            "question": question,
            "gold_answer": answer,
            "response": response_text,
            "model_parsed": str(answer_parsed) if answer_parsed else None,
            "rewards": rewards,
            "weighted_rewards": weighted_rewards,
            "score": score,
            "correct": bool(score > 0),
        }

    async def evaluate(self, *args, **kwargs) -> Dict:
        """Run the full evaluation."""
        print(f"\n{'=' * 60}")
        print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}")
        print(f"{'=' * 60}")
        print(f" Total questions: {len(self.eval_items)}")
        print(f" Model: {self._get_model_name()}")
        print(f" Temperature: {self.config.temperature}")
        print(f" Max concurrent: {self.config.max_concurrent}")
        print(f"{'=' * 60}\n")

        start_time = time.time()

        # Run sequentially to avoid signal/threading issues with math_verify parser
        # The parser uses signals for timeouts which only work in main thread
        from tqdm import tqdm
        # Create OpenAI client from atropos server config
        client = self._get_openai_client()
        model = self._get_model_name()

        results = []
        for item in tqdm(self.eval_items, desc="Evaluating"):
            result = await self.rollout_and_score(item)
            results.append(result)

        # Filter out failed results
        valid_results = [r for r in results if r is not None]

        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return {"error": "No valid results", "accuracy": 0.0}
        # Let verifiers handle everything: rollouts + scoring
        results = await self.vf_env.evaluate(
            client=client,
            model=model,
            sampling_args={
                "temperature": self.config.temperature,
                "max_tokens": self.config.max_token_length,
            },
            num_examples=num_examples,
            max_concurrent=self.config.max_concurrent,
            save_results=False,
        )

        end_time = time.time()

        # Calculate metrics
        total = len(valid_results)
        scores = [r["score"] for r in valid_results]
        correct = sum(1 for r in valid_results if r["correct"])
        # Extract metrics from verifiers output
        rewards = results["reward"]
        per_func_metrics = results["metrics"]  # dict of func_name -> list[float]
        prompts = results["prompt"]
        completions = results["completion"]
        answers = results["answer"]

        avg_score = sum(scores) / total if total > 0 else 0.0
        total = len(rewards)
        correct = sum(1 for r in rewards if r > 0)
        avg_score = sum(rewards) / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        # Per-reward function breakdown
        reward_breakdown = {}
        for i, weight in enumerate(self.reward_weights):
            func_rewards = [r["rewards"][i] for r in valid_results]
            reward_breakdown[f"reward_func_{i}"] = {
                "weight": weight,
                "avg": sum(func_rewards) / len(func_rewards),
                "correct": sum(1 for r in func_rewards if r > 0),
            }
        for func_name, values in per_func_metrics.items():
            if values:
                reward_breakdown[func_name] = {
                    "avg": sum(values) / len(values),
                    "correct": sum(1 for v in values if v > 0),
                }

        metrics = {
            "avg_score": avg_score,
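The new evaluate() path above reads verifiers' output by key ("reward", "metrics") and derives the summary metrics from those lists. A self-contained sketch of the same computation, with invented values (not from the commit):

# Results mapping shaped like the keys read above; values are invented.
results = {
    "reward": [1.0, 0.0, 0.8, 0.0],
    "metrics": {"exact_match": [1.0, 0.0, 1.0, 0.0]},
}

rewards = results["reward"]
total = len(rewards)
correct = sum(1 for r in rewards if r > 0)              # 2
avg_score = sum(rewards) / total if total > 0 else 0.0  # 0.45
accuracy = correct / total if total > 0 else 0.0        # 0.5

reward_breakdown = {}
for func_name, values in results["metrics"].items():
    if values:
        reward_breakdown[func_name] = {
            "avg": sum(values) / len(values),
            "correct": sum(1 for v in values if v > 0),
        }
print(avg_score, accuracy, reward_breakdown)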
@@ -366,22 +254,32 @@ class VerifiersEvaluationEnv(BaseEnv):
        )
        print(f"{'=' * 60}\n")

        # Log to evaluate_log
        samples = [
            {
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": r["question"]},
                    {"role": "assistant", "content": r["response"]},
                ],
                "question": r["question"],
                "gold_answer": r["gold_answer"],
                "model_parsed": r["model_parsed"],
                "score": r["score"],
                "correct": r["correct"],
            }
            for r in valid_results
        ]
        # Log to evaluate_log (atropos's logging system)
        system_prompt = self.vf_env.system_prompt or ""
        samples = []
        for i in range(min(total, 100)):  # Limit samples for logging
            prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
            completion_msgs = completions[i] if completions[i] else []

            # Build full message list
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.extend(prompt_msgs)
            if isinstance(completion_msgs, list):
                messages.extend(completion_msgs)

            samples.append(
                {
                    "messages": messages,
                    "gold_answer": answers[i] if i < len(answers) else "",
                    "score": rewards[i],
                    "correct": rewards[i] > 0,
                    "metrics": {
                        k: v[i] for k, v in per_func_metrics.items() if i < len(v)
                    },
                }
            )

        await self.evaluate_log(
            metrics={"accuracy": accuracy, "avg_score": avg_score},
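The new logging loop above flattens the system prompt, prompt messages, and completion messages into a single chat transcript per sample. The same assembly in isolation, with illustrative values (not from the commit):

# Illustrative inputs in the shapes the loop above expects.
system_prompt = "You are a helpful math tutor."
prompt_msgs = [{"role": "user", "content": "What is 2 + 2?"}]
completion_msgs = [{"role": "assistant", "content": "4"}]

# Flatten system prompt + prompt messages + completion into one transcript.
messages = []
if system_prompt:
    messages.append({"role": "system", "content": system_prompt})
messages.extend(prompt_msgs)
if isinstance(completion_msgs, list):
    messages.extend(completion_msgs)
print(messages)  # [system, user, assistant]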
@@ -430,13 +328,11 @@ class VerifiersEvaluationEnv(BaseEnv):
    # Required abstract method implementations (stubs for evaluation-only mode)
    async def get_next_item(self) -> Optional[Dict]:
        """Not used in evaluation mode."""
        raise NotImplementedError("get_next_item not supported in evaluation-only mode")
        return None

    async def collect_trajectories(self, item) -> Tuple[List, List]:
    async def collect_trajectories(self, item) -> Tuple[List, List]:  # noqa: ARG002
        """Not used in evaluation mode."""
        raise NotImplementedError(
            "collect_trajectories not supported in evaluation-only mode"
        )
        return [], []


if __name__ == "__main__":