revert to using evalbase

balyan.sid@gmail.com 2026-01-23 23:41:32 +05:30
parent 5a20abdce7
commit 4ba69d3a80
2 changed files with 204 additions and 156 deletions

verifiers_eval.py

@@ -14,39 +14,33 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install
Usage:
# Evaluate with local server
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 100 \
--openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
--openai.base_url "http://localhost:8000/v1"
# Evaluate with OpenAI
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 50 \
--openai.model_name "gpt-4o" \
--openai.api_key "$OPENAI_API_KEY" \
--openai.base_url "https://api.openai.com/v1"
python verifiers_eval.py \
--server-url https://api.openai.com/v1 \
--model-name gpt-4o \
--vf-env-name primeintellect/gsm8k \
--max-eval-items 50
# Evaluate with local server
python verifiers_eval.py \
--server-url http://localhost:8000/v1 \
--model-name Qwen/Qwen2.5-7B-Instruct \
--vf-env-name primeintellect/gsm8k
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
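To stand up a local endpoint for the examples above (a sketch, assuming
vLLM is installed; any OpenAI-compatible server works the same way):
    vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000
Ollama exposes the identical protocol at http://localhost:11434/v1.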
"""
import argparse
import json
import time
from typing import Any, Dict, List, Tuple
from typing import Any, Dict
import verifiers as vf
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
# Import ManagedServerAdapter from shared location
from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
@@ -85,128 +79,74 @@ except ImportError:
pass # math_verify not installed
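# Illustrative sketch, not the actual patch (the patched code itself is
# elided from this hunk): signal.alarm/signal.setitimer may only be armed
# in the main thread, so a signal-based timeout raises ValueError when a
# reward function runs inside an asyncio worker thread. A thread-based
# decorator sidesteps signals entirely:
import concurrent.futures
import functools
def _thread_timeout(seconds: float):
    """Timeout decorator that works off the main thread (sketch only)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = pool.submit(fn, *args, **kwargs)
            try:
                # Bound the wait without signals; raises TimeoutError on expiry.
                return future.result(timeout=seconds)
            finally:
                # Never block on a still-running call: the worker thread is
                # abandoned, since Python threads cannot be force-killed.
                pool.shutdown(wait=False)
        return wrapper
    return decorator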
class VfEvalConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
vf_env_name: str = "primeintellect/gsm8k"
env_args: str = "{}" # JSON string for environment-specific args
eval_temperature: float = 0.0
eval_max_tokens: int = 2048
max_eval_items: int = -1 # -1 means evaluate all items
max_concurrent: int = 64
# Override BaseEnvConfig defaults for eval mode
group_size: int = 1
total_steps: int = 1
steps_per_eval: int = 1
use_wandb: bool = True
def get_env_args(self) -> Dict[str, Any]:
"""Parse env_args JSON string into dict."""
if isinstance(self.env_args, dict):
return self.env_args
return json.loads(self.env_args)
class VerifiersEvalEnv(BaseEnv):
class VerifiersEval(EvalBase):
"""
Verifiers Evaluation Environment using BaseEnv pattern.
Verifiers Evaluation using EvalBase pattern.
Uses verifiers' native batch evaluation for efficiency,
with BaseEnv's standardized logging via evaluate_log().
with ManagedServerAdapter for token/logprob tracking.
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
name = "verifiers_eval"
env_config_cls = VfEvalConfig # type: ignore[assignment]
def __init__(
self,
vf_env_name: str = "primeintellect/gsm8k",
env_args: str = "{}",
temperature: float = 0.0,
max_tokens: int = 2048,
max_eval_items: int = -1,
max_concurrent: int = 64,
**kwargs,
):
self.vf_env_name = vf_env_name
self.env_args_str = env_args
self.temperature = temperature
self.max_tokens = max_tokens
self.max_eval_items = max_eval_items
self.max_concurrent = max_concurrent
super().__init__(**kwargs)
@classmethod
def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
"""Return default configurations."""
env_config = VfEvalConfig(
tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
vf_env_name="primeintellect/gsm8k",
wandb_name="verifiers_eval",
)
server_configs = [
APIServerConfig(
model_name="Qwen/Qwen2.5-1.5B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
),
]
return env_config, server_configs
def get_env_args(self) -> Dict[str, Any]:
"""Parse env_args JSON string into dict."""
if isinstance(self.env_args_str, dict):
return self.env_args_str
return json.loads(self.env_args_str)
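# Example of the env_args plumbing (hypothetical key; each verifiers
# environment defines its own kwargs): --env-args '{"num_examples": 100}'
# arrives here as a JSON string, parses to {"num_examples": 100}, and is
# splatted into vf.load_environment(self.vf_env_name, **env_args) in
# setup_data() below.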
async def setup(self):
def setup_data(self) -> list:
"""Load verifiers environment and dataset."""
env_args = self.config.get_env_args()
self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
env_args = self.get_env_args()
self.vf_env = vf.load_environment(self.vf_env_name, **env_args)
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
# Load evaluation dataset
dataset = self.vf_env.get_eval_dataset()
if self.config.max_eval_items > 0:
n = min(len(dataset), self.config.max_eval_items)
if self.max_eval_items > 0:
n = min(len(dataset), self.max_eval_items)
dataset = dataset.select(range(n))
self.data = dataset.to_list()
return dataset.to_list()
async def get_next_item(self):
"""Not used in eval mode - stub implementation."""
return None
async def run_item(self, server: ServerManager, data_item: dict):
"""Not used - verifiers uses batch evaluation in __call__."""
# This won't be called since we override __call__
raise NotImplementedError("VerifiersEval uses batch evaluation in __call__")
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
"""Not used in eval mode - stub implementation."""
_ = item # unused in eval mode
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
inference_logprobs=[],
advantages=[],
ref_logprobs=[],
generation_params=None,
group_overrides=None,
overrides=[],
images=[],
),
[],
)
async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
"""Run evaluation using verifiers with ManagedServer."""
async def __call__(self, server_manager: ServerManager):
"""Run evaluation using verifiers with ManagedServerAdapter."""
start_time = time.time()
# Get server config
if hasattr(self.server, "servers") and self.server.servers:
server_config = self.server.servers[0].config
else:
server_config = self.server_configs[0]
server = server_manager.servers[0]
model_name = server.config.model_name
model_name = server_config.model_name
print(f"\n{'=' * 60}")
print(f"Verifiers Evaluation: {self.config.vf_env_name}")
print(f"{'=' * 60}")
print(f" Model: {model_name}")
print(f" Items: {len(self.data)}")
print(f" Reward functions: {self.reward_func_names}")
print(f" Temperature: {self.config.eval_temperature}")
print(f" Max concurrent: {self.config.max_concurrent}")
print(f"{'=' * 60}\n")
num_examples = (
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
)
num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
# Use ManagedServer for automatic token/logprob tracking
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
async with server_manager.managed_server(tokenizer=None) as managed:
# Create adapter that looks like AsyncOpenAI for verifiers
adapter = ManagedServerAdapter(
managed_server=managed,
base_url=server_config.base_url,
base_url=server.config.base_url,
)
# Use verifiers' batch evaluation
@@ -214,11 +154,11 @@ class VerifiersEvalEnv(BaseEnv):
client=adapter,
model=model_name,
sampling_args={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
},
num_examples=num_examples,
max_concurrent=self.config.max_concurrent,
max_concurrent=self.max_concurrent,
save_results=False,
)
@@ -295,25 +235,135 @@ class VerifiersEvalEnv(BaseEnv):
}
)
# Use BaseEnv's evaluate_log
task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
await self.evaluate_log(
# Log results
task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
evaluate_log(
metrics=metrics,
eval_dir=getattr(self, "eval_dir", None),
task_name=task_name,
model_name=model_name,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": 1,
},
samples=samples,
verbose=True,
verbose=getattr(self, "verbose", True),
)
return metrics
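# Illustrative sketch, NOT the real ManagedServerAdapter: verifiers treats
# the client like an AsyncOpenAI instance, so (under the assumption that
# only the chat-completions path is exercised) a duck-typed shim exposing
# `await client.chat.completions.create(...)` is all the adapter has to
# look like:
class _FakeCompletions:
    def __init__(self, create_fn):
        # create_fn: async callable mimicking AsyncOpenAI's
        # chat.completions.create(model=..., messages=..., **sampling_args)
        self.create = create_fn
class _FakeChat:
    def __init__(self, create_fn):
        self.completions = _FakeCompletions(create_fn)
class _FakeAsyncOpenAI:
    """Minimal AsyncOpenAI look-alike (sketch; assumes only .chat.completions.create)."""
    def __init__(self, create_fn):
        self.chat = _FakeChat(create_fn)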
async def main():
"""Run verifiers evaluation with argparse CLI."""
import os
parser = argparse.ArgumentParser(
description="Evaluate models using Prime Intellect's Verifiers library"
)
parser.add_argument(
"--server-url",
type=str,
default="http://localhost:8000/v1",
help="URL of the inference server (default: http://localhost:8000/v1)",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name to evaluate",
)
parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key (default: uses OPENAI_API_KEY env var)",
)
parser.add_argument(
"--vf-env-name",
type=str,
default="primeintellect/gsm8k",
help="Verifiers environment name (default: primeintellect/gsm8k)",
)
parser.add_argument(
"--env-args",
type=str,
default="{}",
help="JSON string of environment-specific args (default: {})",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Generation temperature (default: 0.0)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=2048,
help="Maximum tokens per completion (default: 2048)",
)
parser.add_argument(
"--max-eval-items",
type=int,
default=-1,
help="Maximum items to evaluate, -1 for all (default: -1)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=64,
help="Maximum concurrent requests (default: 64)",
)
parser.add_argument(
"--eval-dir",
type=str,
default=None,
help="Directory to save evaluation results (default: None)",
)
parser.add_argument(
"--verbose",
action="store_true",
default=True,
help="Print verbose output (default: True)",
)
args = parser.parse_args()
# Get API key from args or environment
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "dummy")
# Create evaluation instance
eval_env = VerifiersEval(
vf_env_name=args.vf_env_name,
env_args=args.env_args,
temperature=args.temperature,
max_tokens=args.max_tokens,
max_eval_items=args.max_eval_items,
max_concurrent=args.max_concurrent,
eval_dir=args.eval_dir,
verbose=args.verbose,
)
# Create server manager
server_manager = ServerManager(
configs=[
APIServerConfig(
api_key=api_key,
base_url=args.server_url,
model_name=args.model_name,
health_check=False,
),
]
)
# Run evaluation
return await eval_env(server_manager)
if __name__ == "__main__":
VerifiersEvalEnv.cli()
import asyncio
asyncio.run(main())