switch eval to use managed server adapter impl. moved managed server adapter
balyan.sid@gmail.com 2026-01-23 23:25:19 +05:30
parent 32d12c05c3
commit 5a20abdce7
4 changed files with 253 additions and 270 deletions


@@ -14,31 +14,53 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install
Usage:
python verifiers_eval.py \
--server-url http://localhost:8000/v1 \
--model-name Qwen/Qwen2.5-7B-Instruct \
--vf-env-name primeintellect/gsm8k \
--max-eval-items 100
# Evaluate with local server
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 100 \
--openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
--openai.base_url "http://localhost:8000/v1"
# Evaluate with OpenAI
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 50 \
--openai.model_name "gpt-4o" \
--openai.api_key "$OPENAI_API_KEY" \
--openai.base_url "https://api.openai.com/v1"
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
import argparse
import asyncio
import json
import time
from typing import Tuple
from typing import Any, Dict, List, Tuple
import verifiers as vf
from openai import AsyncOpenAI
from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
# Import ManagedServerAdapter from shared location
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(_timeout_seconds: int):
"""Replacement timeout decorator that doesn't use signals."""
def _no_signal_timeout(
_timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
):
"""Replacement timeout decorator that doesn't use signals.
Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
"""
# Silence unused parameter warnings - these match the original API signature
del _timeout_seconds, timeout_seconds
def decorator(func):
def wrapper(*args, **kwargs):
@@ -63,108 +85,143 @@ except ImportError:
pass # math_verify not installed
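The patch application itself falls in the hunk elided above; as a rough sketch of the usual monkeypatch pattern, assuming the signal-based decorator is exposed as math_verify.utils.timeout (module path and attribute name are assumptions, not confirmed by this diff):
try:
    import math_verify.utils as _mv_utils  # import path assumed; adjust to the installed layout

    _mv_utils.timeout = _no_signal_timeout  # swap in the async-safe decorator defined above
except ImportError:
    pass  # math_verify not installed; nothing to patch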
class VerifiersEval(EvalBase):
class VfEvalConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
vf_env_name: str = "primeintellect/gsm8k"
env_args: str = "{}" # JSON string for environment-specific args
eval_temperature: float = 0.0
eval_max_tokens: int = 2048
max_eval_items: int = -1 # -1 means evaluate all items
max_concurrent: int = 64
# Override BaseEnvConfig defaults for eval mode
group_size: int = 1
total_steps: int = 1
steps_per_eval: int = 1
use_wandb: bool = True
def get_env_args(self) -> Dict[str, Any]:
"""Parse env_args JSON string into dict."""
if isinstance(self.env_args, dict):
return self.env_args
return json.loads(self.env_args)
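As a side note, a small sketch of how env_args travels from the config into the verifiers environment; the "use_few_shot" key is made up for illustration, and direct construction assumes the inherited BaseEnvConfig fields all have defaults:
cfg = VfEvalConfig(env_args='{"use_few_shot": false}')  # hypothetical example key
env_kwargs = cfg.get_env_args()  # -> {"use_few_shot": False}
# setup() then forwards it: vf.load_environment(cfg.vf_env_name, **env_kwargs)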
class VerifiersEvalEnv(BaseEnv):
"""
Verifiers Evaluation using EvalBase pattern.
Verifiers Evaluation Environment using BaseEnv pattern.
Uses verifiers' native batch evaluation for efficiency,
with EvalBase's standardized logging via evaluate_log().
with BaseEnv's standardized logging via evaluate_log().
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
def __init__(
self,
vf_env_name: str = "primeintellect/gsm8k",
env_args: dict = None,
temperature: float = 0.0,
max_tokens: int = 2048,
max_eval_items: int = -1,
max_concurrent: int = 64,
eval_dir: str = None,
verbose: bool = True,
**kwargs,
):
self.vf_env_name = vf_env_name
self.env_args = env_args or {}
self.temperature = temperature
self.max_tokens = max_tokens
self.max_eval_items = max_eval_items
self.max_concurrent = max_concurrent
name = "verifiers_eval"
env_config_cls = VfEvalConfig # type: ignore[assignment]
# Load verifiers environment
self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
@classmethod
def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
"""Return default configurations."""
env_config = VfEvalConfig(
tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
vf_env_name="primeintellect/gsm8k",
wandb_name="verifiers_eval",
)
server_configs = [
APIServerConfig(
model_name="Qwen/Qwen2.5-1.5B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
),
]
return env_config, server_configs
async def setup(self):
"""Load verifiers environment and dataset."""
env_args = self.config.get_env_args()
self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
# Initialize EvalBase (calls setup_data)
super().__init__(
eval_dir=eval_dir,
verbose=verbose,
**kwargs,
)
def get_generation_params(self):
"""Generation params for logging."""
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": 1,
}
def setup_data(self) -> list:
"""Return evaluation dataset from verifiers environment."""
# Load evaluation dataset
dataset = self.vf_env.get_eval_dataset()
if self.max_eval_items > 0:
n = min(len(dataset), self.max_eval_items)
if self.config.max_eval_items > 0:
n = min(len(dataset), self.config.max_eval_items)
dataset = dataset.select(range(n))
return dataset.to_list()
self.data = dataset.to_list()
async def run_item(
self, server: ServerManager, data_item: dict # noqa: ARG002
) -> Tuple[dict, list]:
"""Not used - we override __call__ for batch evaluation."""
raise NotImplementedError(
"VerifiersEval uses batch evaluation via __call__, not per-item run_item"
async def get_next_item(self):
"""Not used in eval mode - stub implementation."""
return None
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
"""Not used in eval mode - stub implementation."""
_ = item # unused in eval mode
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
inference_logprobs=[],
advantages=[],
ref_logprobs=[],
generation_params=None,
group_overrides=None,
overrides=[],
images=[],
),
[],
)
async def __call__(self, server_manager: ServerManager):
"""Run evaluation using verifiers' native batch machinery."""
async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
"""Run evaluation using verifiers with ManagedServer."""
start_time = time.time()
# Create OpenAI client from server config
server = server_manager.servers[0]
client = AsyncOpenAI(
api_key=server.config.api_key or "x",
base_url=server.config.base_url,
timeout=getattr(server.config, "timeout", 600),
)
model = server.config.model_name
# Get server config
if hasattr(self.server, "servers") and self.server.servers:
server_config = self.server.servers[0].config
else:
server_config = self.server_configs[0]
model_name = server_config.model_name
print(f"\n{'=' * 60}")
print(f"Verifiers Evaluation: {self.vf_env_name}")
print(f"Verifiers Evaluation: {self.config.vf_env_name}")
print(f"{'=' * 60}")
print(f" Model: {model}")
print(f" Model: {model_name}")
print(f" Items: {len(self.data)}")
print(f" Reward functions: {self.reward_func_names}")
print(f" Temperature: {self.temperature}")
print(f" Max concurrent: {self.max_concurrent}")
print(f" Temperature: {self.config.eval_temperature}")
print(f" Max concurrent: {self.config.max_concurrent}")
print(f"{'=' * 60}\n")
num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
# Use verifiers' batch evaluation
results = await self.vf_env.evaluate(
client=client,
model=model,
sampling_args={
"temperature": self.temperature,
"max_tokens": self.max_tokens,
},
num_examples=num_examples,
max_concurrent=self.max_concurrent,
save_results=False,
num_examples = (
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
)
# Use ManagedServer for automatic token/logprob tracking
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Create adapter that looks like AsyncOpenAI for verifiers
adapter = ManagedServerAdapter(
managed_server=managed,
base_url=server_config.base_url,
)
# Use verifiers' batch evaluation
results = await self.vf_env.evaluate(
client=adapter,
model=model_name,
sampling_args={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
},
num_examples=num_examples,
max_concurrent=self.config.max_concurrent,
save_results=False,
)
end_time = time.time()
# Extract from verifiers output
@@ -193,6 +250,11 @@ class VerifiersEval(EvalBase):
"avg_score": avg_score,
}
# Add per-function metrics
for func_name, data in reward_breakdown.items():
metrics[f"{func_name}_avg"] = data["avg"]
metrics[f"{func_name}_correct_rate"] = data["correct"] / total
# Print results summary
print(f"\n{'=' * 60}")
print("Verifiers Evaluation Results")
@@ -233,114 +295,25 @@ class VerifiersEval(EvalBase):
}
)
# Use EvalBase's evaluate_log
task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
evaluate_log(
# Use BaseEnv's evaluate_log
task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
await self.evaluate_log(
metrics=metrics,
eval_dir=getattr(self, "eval_dir", None),
task_name=task_name,
model_name=model,
model_name=model_name,
start_time=start_time,
end_time=end_time,
generation_parameters=self.get_generation_params(),
generation_parameters={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
"n": 1,
},
samples=samples,
verbose=getattr(self, "verbose", False),
verbose=True,
)
return metrics
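For orientation, the metrics dict returned here is shaped roughly as follows; the reward-function names and numbers are illustrative and come from the environment's rubric at runtime:
# metrics = {
#     "avg_score": 0.82,                    # mean total reward across evaluated items
#     "correct_answer_avg": 0.82,           # per-reward-function average (name illustrative)
#     "correct_answer_correct_rate": 0.82,  # data["correct"] / total for that function
# }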
async def main():
"""CLI entry point for verifiers evaluation."""
import os
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
parser = argparse.ArgumentParser(
description="Evaluate models using Verifiers environments"
)
# Server args (same as eval_runner)
parser.add_argument(
"--server-url",
type=str,
default="http://localhost:8000/v1",
help="URL of the inference server",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name to evaluate",
)
parser.add_argument(
"--api-key",
type=str,
default=os.getenv("OPENAI_API_KEY", "x"),
help="API key (defaults to OPENAI_API_KEY env var)",
)
# Verifiers-specific args
parser.add_argument(
"--vf-env-name",
type=str,
default="primeintellect/gsm8k",
help="Verifiers environment name (e.g., primeintellect/gsm8k)",
)
parser.add_argument(
"--max-eval-items",
type=int,
default=-1,
help="Maximum items to evaluate (-1 for all)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Generation temperature",
)
parser.add_argument(
"--max-tokens",
type=int,
default=2048,
help="Maximum tokens per completion",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=64,
help="Maximum concurrent requests",
)
parser.add_argument(
"--eval-dir",
type=str,
default=None,
help="Directory to save evaluation results",
)
args = parser.parse_args()
# Create server manager
server_manager = ServerManager(
configs=[
APIServerConfig(
api_key=args.api_key,
base_url=args.server_url,
model_name=args.model_name,
health_check=False,
),
]
)
# Create and run evaluation
eval_instance = VerifiersEval(
vf_env_name=args.vf_env_name,
max_eval_items=args.max_eval_items,
temperature=args.temperature,
max_tokens=args.max_tokens,
max_concurrent=args.max_concurrent,
eval_dir=args.eval_dir,
verbose=True,
)
return await eval_instance(server_manager)
if __name__ == "__main__":
asyncio.run(main())
VerifiersEvalEnv.cli()