mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

revert to using evalbase

This commit is contained in:
parent 5a20abdce7
commit 4ba69d3a80

2 changed files with 204 additions and 156 deletions
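For orientation, a condensed sketch of the call pattern this revert restores, assembled from main() in verifiers_eval.py as shown in the diff below; the endpoint URL, model name, and item count are placeholders, and the import path for VerifiersEval is an assumption.

import asyncio

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from verifiers_eval import VerifiersEval  # assumes verifiers_eval.py is on the import path


async def run() -> dict:
    # EvalBase-style evals are plain objects: construct one, then await the instance.
    eval_env = VerifiersEval(
        vf_env_name="primeintellect/gsm8k",
        max_eval_items=50,  # placeholder: evaluate only the first 50 items
    )
    server_manager = ServerManager(
        configs=[
            APIServerConfig(
                api_key="x",
                base_url="http://localhost:8000/v1",  # placeholder endpoint
                model_name="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
                health_check=False,
            ),
        ]
    )
    return await eval_env(server_manager)


if __name__ == "__main__":
    asyncio.run(run())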
@@ -14,39 +14,33 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install

Usage:

    # Evaluate with local server
    python verifiers_eval.py evaluate \
        --env.vf_env_name "primeintellect/gsm8k" \
        --env.max_eval_items 100 \
        --openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
        --openai.base_url "http://localhost:8000/v1"

    # Evaluate with OpenAI
    python verifiers_eval.py evaluate \
        --env.vf_env_name "primeintellect/gsm8k" \
        --env.max_eval_items 50 \
        --openai.model_name "gpt-4o" \
        --openai.api_key "$OPENAI_API_KEY" \
        --openai.base_url "https://api.openai.com/v1"
    python verifiers_eval.py \
        --server-url https://api.openai.com/v1 \
        --model-name gpt-4o \
        --vf-env-name primeintellect/gsm8k \
        --max-eval-items 50

    # Evaluate with local server
    python verifiers_eval.py \
        --server-url http://localhost:8000/v1 \
        --model-name Qwen/Qwen2.5-7B-Instruct \
        --vf-env-name primeintellect/gsm8k

Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""

import argparse
import json
import time
from typing import Any, Dict, List, Tuple
from typing import Any, Dict

import verifiers as vf

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)

# Import ManagedServerAdapter from shared location
from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager

# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
@@ -85,128 +79,74 @@ except ImportError:
    pass  # math_verify not installed


class VfEvalConfig(BaseEnvConfig):
    """Configuration for Verifiers evaluation environment."""

    vf_env_name: str = "primeintellect/gsm8k"
    env_args: str = "{}"  # JSON string for environment-specific args
    eval_temperature: float = 0.0
    eval_max_tokens: int = 2048
    max_eval_items: int = -1  # -1 means evaluate all items
    max_concurrent: int = 64

    # Override BaseEnvConfig defaults for eval mode
    group_size: int = 1
    total_steps: int = 1
    steps_per_eval: int = 1
    use_wandb: bool = True

    def get_env_args(self) -> Dict[str, Any]:
        """Parse env_args JSON string into dict."""
        if isinstance(self.env_args, dict):
            return self.env_args
        return json.loads(self.env_args)

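
# Illustration (not part of this commit): how env_args flows through the code.
# get_env_args() parses the --env-args JSON string; setup_data() then forwards
# the resulting dict as keyword arguments to the environment loader.
import json

import verifiers as vf

env_args = json.loads("{}")  # the default --env-args string: no extra arguments
vf_env = vf.load_environment("primeintellect/gsm8k", **env_args)
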
class VerifiersEvalEnv(BaseEnv):
class VerifiersEval(EvalBase):
    """
    Verifiers Evaluation Environment using BaseEnv pattern.
    Verifiers Evaluation using EvalBase pattern.

    Uses verifiers' native batch evaluation for efficiency,
    with BaseEnv's standardized logging via evaluate_log().
    with ManagedServerAdapter for token/logprob tracking.

    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
    """

    name = "verifiers_eval"
    env_config_cls = VfEvalConfig  # type: ignore[assignment]
    def __init__(
        self,
        vf_env_name: str = "primeintellect/gsm8k",
        env_args: str = "{}",
        temperature: float = 0.0,
        max_tokens: int = 2048,
        max_eval_items: int = -1,
        max_concurrent: int = 64,
        **kwargs,
    ):
        self.vf_env_name = vf_env_name
        self.env_args_str = env_args
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_eval_items = max_eval_items
        self.max_concurrent = max_concurrent
        super().__init__(**kwargs)

    @classmethod
    def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
        """Return default configurations."""
        env_config = VfEvalConfig(
            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
            vf_env_name="primeintellect/gsm8k",
            wandb_name="verifiers_eval",
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2.5-1.5B-Instruct",
                base_url="http://localhost:9001/v1",
                api_key="x",
            ),
        ]
        return env_config, server_configs
    def get_env_args(self) -> Dict[str, Any]:
        """Parse env_args JSON string into dict."""
        if isinstance(self.env_args_str, dict):
            return self.env_args_str
        return json.loads(self.env_args_str)

    async def setup(self):
    def setup_data(self) -> list:
        """Load verifiers environment and dataset."""
        env_args = self.config.get_env_args()
        self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
        env_args = self.get_env_args()
        self.vf_env = vf.load_environment(self.vf_env_name, **env_args)
        self.reward_func_names = self.vf_env.rubric._get_reward_func_names()

        # Load evaluation dataset
        dataset = self.vf_env.get_eval_dataset()
        if self.config.max_eval_items > 0:
            n = min(len(dataset), self.config.max_eval_items)
        if self.max_eval_items > 0:
            n = min(len(dataset), self.max_eval_items)
            dataset = dataset.select(range(n))
        self.data = dataset.to_list()
        return dataset.to_list()

    async def get_next_item(self):
        """Not used in eval mode - stub implementation."""
        return None
    async def run_item(self, server: ServerManager, data_item: dict):
        """Not used - verifiers uses batch evaluation in __call__."""
        # This won't be called since we override __call__
        raise NotImplementedError("VerifiersEval uses batch evaluation in __call__")

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Not used in eval mode - stub implementation."""
        _ = item  # unused in eval mode
        return (
            ScoredDataGroup(
                tokens=[],
                masks=[],
                scores=[],
                messages=[],
                inference_logprobs=[],
                advantages=[],
                ref_logprobs=[],
                generation_params=None,
                group_overrides=None,
                overrides=[],
                images=[],
            ),
            [],
        )

    async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
        """Run evaluation using verifiers with ManagedServer."""
    async def __call__(self, server_manager: ServerManager):
        """Run evaluation using verifiers with ManagedServerAdapter."""
        start_time = time.time()

        # Get server config
        if hasattr(self.server, "servers") and self.server.servers:
            server_config = self.server.servers[0].config
        else:
            server_config = self.server_configs[0]
        server = server_manager.servers[0]
        model_name = server.config.model_name

        model_name = server_config.model_name

        print(f"\n{'=' * 60}")
        print(f"Verifiers Evaluation: {self.config.vf_env_name}")
        print(f"{'=' * 60}")
        print(f" Model: {model_name}")
        print(f" Items: {len(self.data)}")
        print(f" Reward functions: {self.reward_func_names}")
        print(f" Temperature: {self.config.eval_temperature}")
        print(f" Max concurrent: {self.config.max_concurrent}")
        print(f"{'=' * 60}\n")

        num_examples = (
            self.config.max_eval_items if self.config.max_eval_items > 0 else -1
        )
        num_examples = self.max_eval_items if self.max_eval_items > 0 else -1

        # Use ManagedServer for automatic token/logprob tracking
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        async with server_manager.managed_server(tokenizer=None) as managed:
            # Create adapter that looks like AsyncOpenAI for verifiers
            adapter = ManagedServerAdapter(
                managed_server=managed,
                base_url=server_config.base_url,
                base_url=server.config.base_url,
            )
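
            # Note (per the class docstring): ManagedServerAdapter presents an
            # AsyncOpenAI-compatible client surface to verifiers, while the managed
            # server it wraps records token and logprob data for each request.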

            # Use verifiers' batch evaluation
@@ -214,11 +154,11 @@ class VerifiersEvalEnv(BaseEnv):
                client=adapter,
                model=model_name,
                sampling_args={
                    "temperature": self.config.eval_temperature,
                    "max_tokens": self.config.eval_max_tokens,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                num_examples=num_examples,
                max_concurrent=self.config.max_concurrent,
                max_concurrent=self.max_concurrent,
                save_results=False,
            )

@@ -295,25 +235,135 @@ class VerifiersEvalEnv(BaseEnv):
                }
            )

        # Use BaseEnv's evaluate_log
        task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
        await self.evaluate_log(
        # Log results
        task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
        evaluate_log(
            metrics=metrics,
            eval_dir=getattr(self, "eval_dir", None),
            task_name=task_name,
            model_name=model_name,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": self.config.eval_temperature,
                "max_tokens": self.config.eval_max_tokens,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                "n": 1,
            },
            samples=samples,
            verbose=True,
            verbose=getattr(self, "verbose", True),
        )

        return metrics


async def main():
    """Run verifiers evaluation with argparse CLI."""
    import os

    parser = argparse.ArgumentParser(
        description="Evaluate models using Prime Intellect's Verifiers library"
    )
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the inference server (default: http://localhost:8000/v1)",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Model name to evaluate",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key (default: uses OPENAI_API_KEY env var)",
    )
    parser.add_argument(
        "--vf-env-name",
        type=str,
        default="primeintellect/gsm8k",
        help="Verifiers environment name (default: primeintellect/gsm8k)",
    )
    parser.add_argument(
        "--env-args",
        type=str,
        default="{}",
        help="JSON string of environment-specific args (default: {})",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Generation temperature (default: 0.0)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per completion (default: 2048)",
    )
    parser.add_argument(
        "--max-eval-items",
        type=int,
        default=-1,
        help="Maximum items to evaluate, -1 for all (default: -1)",
    )
    parser.add_argument(
        "--max-concurrent",
        type=int,
        default=64,
        help="Maximum concurrent requests (default: 64)",
    )
    parser.add_argument(
        "--eval-dir",
        type=str,
        default=None,
        help="Directory to save evaluation results (default: None)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=True,
        help="Print verbose output (default: True)",
    )

    args = parser.parse_args()

    # Get API key from args or environment
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "dummy")

    # Create evaluation instance
    eval_env = VerifiersEval(
        vf_env_name=args.vf_env_name,
        env_args=args.env_args,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        max_eval_items=args.max_eval_items,
        max_concurrent=args.max_concurrent,
        eval_dir=args.eval_dir,
        verbose=args.verbose,
    )

    # Create server manager
    server_manager = ServerManager(
        configs=[
            APIServerConfig(
                api_key=api_key,
                base_url=args.server_url,
                model_name=args.model_name,
                health_check=False,
            ),
        ]
    )
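    # Note: health_check=False is assumed (from the flag name) to skip probing
    # the endpoint before use, which matters when pointing at hosted APIs such
    # as https://api.openai.com/v1 rather than a local vLLM/SGLang server.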

    # Run evaluation
    return await eval_env(server_manager)


if __name__ == "__main__":
    VerifiersEvalEnv.cli()
    import asyncio

    asyncio.run(main())