Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

Merge pull request #305 from alt-glitch/sid/verifiers

Verifiers Integration

commit 5af29933a7
8 changed files with 949 additions and 1 deletion
@@ -151,6 +151,7 @@ If you're looking to get into developing the repo or using the environments:
pip install -e . # for using
pip install -e .[dev] # for development
pip install -e .[examples] # for running examples
pip install -e .[verifiers] # for verifiers integration
pip install -e .[all] # for everything
```
@@ -509,3 +509,63 @@ class ManagedServer:
            self.sequences.clear()
        else:
            self.current_nodes.clear()


class ManagedServerAdapter:
    """
    Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.

    Implements the subset of the AsyncOpenAI interface commonly used:
    - client.chat.completions.create()
    - client.completions.create()
    - client.base_url

    This allows libraries like verifiers to use ManagedServer transparently
    while still getting automatic token and logprob tracking.
    """

    def __init__(self, managed_server: ManagedServer, base_url: str):
        """
        Initialize the adapter.

        Args:
            managed_server: The ManagedServer instance to wrap
            base_url: The base URL to expose (for compatibility checks)
        """
        self._managed = managed_server
        self.base_url = base_url
        self.chat = self._ChatNamespace(self._managed)
        self.completions = self._CompletionsNamespace(self._managed)

    class _ChatNamespace:
        def __init__(self, managed: ManagedServer):
            self._managed = managed
            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)

    class _ChatCompletionsNamespace:
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            return await self._managed.chat_completion(**kwargs)

    class _CompletionsNamespace:
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            return await self._managed.completion(**kwargs)

    async def post(self, path: str, body: dict, cast_to: type):
        """Not supported - raises NotImplementedError."""
        raise NotImplementedError(
            f"ManagedServerAdapter does not support post() for path '{path}'. "
            "This is used for vLLM interleaved rollouts. Use standard chat completions."
        )

    def copy(self, **kwargs):
        """Not supported - raises NotImplementedError."""
        raise NotImplementedError(
            "ManagedServerAdapter does not support copy(). "
            "This is used for vLLM tokenization endpoints."
        )
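A usage sketch of the adapter, mirroring how `verifiers_eval.py` later in this PR constructs it; the `server_manager`, `tokenizer`, model name, and URL below are placeholders:

```python
# Sketch: wrap a ManagedServer so code expecting AsyncOpenAI can drive it.
# server_manager, tokenizer, model name, and base_url are placeholders.
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
    client = ManagedServerAdapter(
        managed_server=managed,
        base_url="http://localhost:9001/v1",
    )
    # The wrapped library only touches this AsyncOpenAI-shaped surface:
    response = await client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
    )
```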
@@ -177,6 +177,8 @@ class VLLMServer(APIServer):
            prompt_tokens = prompt_tokens[1:]
        if "max_new_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
        if "max_completion_tokens" in kwargs:
            kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
        if "model" in kwargs:
            kwargs.pop("model")
        # Prepare request for VLLM native API
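The effect of this normalization is easiest to see on a concrete kwargs dict; a minimal standalone sketch of the same logic:

```python
# Standalone sketch of the kwargs normalization above: unify the max-token
# aliases onto "max_tokens" and drop "model" before building the VLLM
# native-API request, as the diff does.
def normalize_kwargs(kwargs: dict) -> dict:
    if "max_new_tokens" in kwargs:
        kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
    if "max_completion_tokens" in kwargs:
        kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
    kwargs.pop("model", None)
    return kwargs


assert normalize_kwargs({"max_new_tokens": 256, "model": "x"}) == {"max_tokens": 256}
```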
@@ -11,7 +11,159 @@ This directory contains various environments for training and evaluating languag

---

### Prime Intellect Verifiers Integration

A flexible environment that integrates with the [Verifiers](https://docs.primeintellect.ai/) ecosystem, allowing you to use any registered Prime environment for RL training, SFT data generation, or evaluation.

**Files:**
- `environments/verifiers_server.py` - Training and SFT data generation
- `environments/eval_environments/verifiers_eval.py` - Standalone evaluation

**Dependencies:**

- `verifiers` Python package (install via `pip install verifiers` or include in your environment)
- Prime CLI for environment management (`uv tool install prime`)
- Prime CLI login required (`prime login`)
- Environment installation (`prime env install owner/env_name`)

**Supported Modes:**

| Mode | File | Description |
|------|------|-------------|
| `serve` | `verifiers_server.py` | RL training with a local inference server (requires ManagedServer for logprobs) |
| `process` | `verifiers_server.py` | SFT data generation with ANY API (OpenAI, Claude, local, etc.) |
| `evaluate` | `verifiers_server.py` | Quick evaluation using ManagedServer |
| `evaluate` | `verifiers_eval.py` | Standalone evaluation with detailed metrics and retry logic |

**Input Format:**

- Loaded dynamically from the specified Prime environment via `vf.load_environment()`
- Each item contains:
  - `question`: The problem/prompt
  - `answer`: The expected answer for verification
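For example, a GSM8K-style item might look like the following (values purely illustrative):

```python
# Illustrative item shape only; real fields come from the Prime environment.
item = {
    "question": "Natalia sold 48 clips in April and half as many in May. How many in total?",
    "answer": "72",
}
```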

**System Prompt:**

- Dynamically loaded from the Prime environment's `system_prompt` configuration

**Reward Function:**

- Uses the environment's **rubric** system with:
  - `parser`: Extracts answers from completions (e.g., `parser.parse_answer(completion)`)
  - `funcs`: List of reward functions that receive `(parser, completion, answer)`
  - `weights`: Weights for combining reward functions (normalized to sum to 1.0)
- Final score is the weighted sum of all reward function outputs, as sketched below
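A minimal sketch of that combination step; `parser`, `funcs`, and `weights` stand in for the rubric members described above, and the reward functions themselves are hypothetical:

```python
# Sketch of rubric-style weighted scoring (not the verifiers API itself).
def rubric_score(parser, completion: str, answer: str,
                 funcs: list, weights: list[float]) -> float:
    total = sum(weights)
    normalized = [w / total for w in weights]  # normalize to sum to 1.0
    # Final score: weighted sum of all reward function outputs.
    return sum(w * f(parser, completion, answer) for f, w in zip(funcs, normalized))
```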

**W&B Metrics Logged (Training - `verifiers_server.py`):**

| Metric | Description |
|--------|-------------|
| `train/percent_correct` | Average score from verifiers reward functions (0-1) |
| `train/rollouts` | Table of tokenized completions with scores |
| `train/completion_lengths_*` | Response length statistics (std, min, max, p95) |
| `server/server_0_request_time_*` | API latency metrics (avg, std, 99p) |
| `eval/avg_total_score` | Average score on the evaluation dataset |

**Output (Evaluation - `verifiers_eval.py`):**

Uses `evaluate_log()` from `EvalBase` to output:
- Console: Metrics table with accuracy, avg_score, time, and a per-reward-function breakdown
- File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified)

**Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `vf_env_name` | str | `""` | Prime environment identifier (e.g., `"will/wordle"`, `"primeintellect/gsm8k"`) |
| `env_args` | Dict | `{}` | Additional arguments passed to `vf.load_environment()`. See the environment-specific documentation for these args. |
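On the CLI, `env_args` arrives as a JSON string and is parsed by `VfEnvConfig.get_env_args()` before being splatted into `vf.load_environment()`. A sketch of that flow; the `use_think` key is a hypothetical environment argument and other config fields are omitted:

```python
# Sketch of env_args flow; "use_think" is a hypothetical key, and the
# remaining required config fields are omitted for brevity.
config = VfEnvConfig(vf_env_name="primeintellect/gsm8k", env_args='{"use_think": true}')
vf_env = vf.load_environment(config.vf_env_name, **config.get_env_args())
```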

**CLI Options (`verifiers_eval.py`):**

Uses a simple argparse CLI with direct arguments:

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server |
| `--model-name` | str | (required) | Model name to evaluate |
| `--api-key` | str | `$OPENAI_API_KEY` | API key (uses the env var if not specified) |
| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier |
| `--temperature` | float | `0.0` | Temperature for generation |
| `--max-tokens` | int | `2048` | Maximum tokens per completion |
| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) |
| `--max-concurrent` | int | `64` | Maximum concurrent requests |
| `--eval-dir` | str | `None` | Directory to save evaluation results |

**Usage Examples:**

```bash
# RL Training (requires local vLLM/SGLang server)
python verifiers_server.py serve \
    --env.vf_env_name "will/wordle" \
    --openai.base_url http://localhost:9001/v1 \
    --slurm false

# SFT Data Generation with OpenAI GPT-4o
python verifiers_server.py process \
    --env.vf_env_name "will/wordle" \
    --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
    --env.total_steps 100 \
    --env.group_size 4 \
    --openai.model_name gpt-4o \
    --openai.base_url https://api.openai.com/v1

# SFT Data Generation with local server
python verifiers_server.py process \
    --env.vf_env_name "will/wordle" \
    --env.data_path_to_save_groups local_sft_data.jsonl \
    --openai.base_url http://localhost:9001/v1

# Quick Evaluation via verifiers_server.py
python verifiers_server.py evaluate \
    --env.vf_env_name "will/wordle" \
    --openai.base_url http://localhost:9001/v1

# Standalone Evaluation with OpenAI (verifiers_eval.py)
python eval_environments/verifiers_eval.py \
    --server-url https://api.openai.com/v1 \
    --model-name gpt-4o \
    --vf-env-name primeintellect/gsm8k

# Quick test run with limited items
python eval_environments/verifiers_eval.py \
    --server-url https://api.openai.com/v1 \
    --model-name gpt-4o-mini \
    --vf-env-name primeintellect/alphabet-sort \
    --max-eval-items 10

# Evaluation with local server and results saved
python eval_environments/verifiers_eval.py \
    --server-url http://localhost:9001/v1 \
    --model-name Qwen/Qwen2.5-7B-Instruct \
    --vf-env-name primeintellect/gsm8k \
    --eval-dir ./eval_results
```

**Key Implementation Details:**

- **RL Training Mode (`serve`)**: Uses `ManagedServer` for the token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
- **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` with a simple argparse CLI. Uses verifiers' native batch evaluation with `ManagedServerAdapter` for token/logprob tracking and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.

**Prime Environment Installation:**

```bash
# Install Prime CLI
uv tool install prime

# Login to Prime
prime login

# Install an environment (e.g., Wordle, GSM8K)
prime env install will/wordle
prime env install primeintellect/gsm8k

# List available environments
prime env list
```

### Letter Counting Environment (`letter_counting_environment.py`)
environments/configs/verifiers.yaml (31 lines, new file)

@@ -0,0 +1,31 @@
# Verifiers environment configuration
# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml
#
# For SFT data generation with external API:
#   python environments/verifiers_server.py process \
#     --env.vf_env_name primeintellect/gsm8k \
#     --env.data_path_to_save_groups output.jsonl \
#     --openai.base_url https://api.openai.com/v1 \
#     --openai.api_key $OPENAI_API_KEY \
#     --openai.model_name gpt-4o

env:
  vf_env_name: "primeintellect/gsm8k"  # Prime Env Hub environment
  env_args: {}
  group_size: 8
  max_token_length: 2048
  tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct"
  rollout_server_url: "http://localhost:8000"
  use_wandb: true
  wandb_name: "verifiers"
  total_steps: 1000
  batch_size: 4
  steps_per_eval: 100

openai:
  - model_name: "Qwen/Qwen2.5-1.5B-Instruct"
    base_url: "http://localhost:9001/v1"
    api_key: "x"

slurm: false
testing: false
environments/eval_environments/verifiers_eval.py (369 lines, new file)

@@ -0,0 +1,369 @@
"""
Verifiers Evaluation Environment for Atropos

This environment evaluates models using Prime Intellect's Verifiers library.
It supports any environment registered with the Verifiers ecosystem.

Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible
client and verifiers handles generation, parsing, and scoring.

To install a Verifiers/Prime environment:
    1. uv tool install prime
    2. prime login
    3. prime env install will/wordle (or any owner/environment)
Docs: https://docs.primeintellect.ai/tutorials-environments/install

Usage:
    # Evaluate with OpenAI
    python verifiers_eval.py \
        --server-url https://api.openai.com/v1 \
        --model-name gpt-4o \
        --vf-env-name primeintellect/gsm8k \
        --max-eval-items 50

    # Evaluate with local server
    python verifiers_eval.py \
        --server-url http://localhost:8000/v1 \
        --model-name Qwen/Qwen2.5-7B-Instruct \
        --vf-env-name primeintellect/gsm8k

Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""

import argparse
import json
import time
from typing import Any, Dict

import verifiers as vf

from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager

# Patch math_verify timeout to work in async context.
# The signal-based timeout doesn't work in non-main threads (asyncio event loop).


def _no_signal_timeout(
    _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
):
    """Replacement timeout decorator that doesn't use signals.

    Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
    """
    # Silence unused parameter warnings - these match the original API signature
    del _timeout_seconds, timeout_seconds

    def decorator(func):
        def wrapper(*args, **kwargs):
            # Just call the function without timeout - safe in async context
            return func(*args, **kwargs)

        return wrapper

    return decorator


try:
    import math_verify.grader
    import math_verify.parser
    import math_verify.utils

    # Patch all modules that use the timeout decorator
    math_verify.utils.timeout = _no_signal_timeout
    math_verify.parser.timeout = _no_signal_timeout
    math_verify.grader.timeout = _no_signal_timeout
except ImportError:
    pass  # math_verify not installed


class VerifiersEval(EvalBase):
    """
    Verifiers Evaluation using the EvalBase pattern.

    Uses verifiers' native batch evaluation for efficiency,
    with ManagedServerAdapter for token/logprob tracking.

    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
    """

    def __init__(
        self,
        vf_env_name: str = "primeintellect/gsm8k",
        env_args: str = "{}",
        temperature: float = 0.0,
        max_tokens: int = 2048,
        max_eval_items: int = -1,
        max_concurrent: int = 64,
        **kwargs,
    ):
        self.vf_env_name = vf_env_name
        self.env_args_str = env_args
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_eval_items = max_eval_items
        self.max_concurrent = max_concurrent
        super().__init__(**kwargs)

    def get_env_args(self) -> Dict[str, Any]:
        """Parse env_args JSON string into dict."""
        if isinstance(self.env_args_str, dict):
            return self.env_args_str
        return json.loads(self.env_args_str)

    def setup_data(self) -> list:
        """Load verifiers environment and dataset."""
        env_args = self.get_env_args()
        self.vf_env = vf.load_environment(self.vf_env_name, **env_args)
        self.reward_func_names = self.vf_env.rubric._get_reward_func_names()

        # Load evaluation dataset
        dataset = self.vf_env.get_eval_dataset()
        if self.max_eval_items > 0:
            n = min(len(dataset), self.max_eval_items)
            dataset = dataset.select(range(n))
        return dataset.to_list()

    async def run_item(self, server: ServerManager, data_item: dict):
        """Not used - verifiers uses batch evaluation in __call__."""
        # This won't be called since we override __call__
        raise NotImplementedError("VerifiersEval uses batch evaluation in __call__")

    async def __call__(self, server_manager: ServerManager):
        """Run evaluation using verifiers with ManagedServerAdapter."""
        start_time = time.time()

        # Get server config
        server = server_manager.servers[0]
        model_name = server.config.model_name

        num_examples = self.max_eval_items if self.max_eval_items > 0 else -1

        # Use ManagedServer for automatic token/logprob tracking
        async with server_manager.managed_server(tokenizer=None) as managed:
            # Create adapter that looks like AsyncOpenAI for verifiers
            adapter = ManagedServerAdapter(
                managed_server=managed,
                base_url=server.config.base_url,
            )

            # Use verifiers' batch evaluation
            results = await self.vf_env.evaluate(
                client=adapter,
                model=model_name,
                sampling_args={
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                num_examples=num_examples,
                max_concurrent=self.max_concurrent,
                save_results=False,
            )

        end_time = time.time()

        # Extract from verifiers output
        rewards = results["reward"]
        per_func_metrics = results["metrics"]
        prompts = results["prompt"]
        completions = results["completion"]
        answers = results["answer"]

        total = len(rewards)
        correct = sum(1 for r in rewards if r > 0)
        avg_score = sum(rewards) / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        # Per-reward function breakdown
        reward_breakdown = {}
        for func_name, values in per_func_metrics.items():
            if values:
                reward_breakdown[func_name] = {
                    "avg": sum(values) / len(values),
                    "correct": sum(1 for v in values if v > 0),
                }

        metrics = {
            "accuracy": accuracy,
            "avg_score": avg_score,
        }

        # Add per-function metrics
        for func_name, data in reward_breakdown.items():
            metrics[f"{func_name}_avg"] = data["avg"]
            metrics[f"{func_name}_correct_rate"] = data["correct"] / total

        # Print results summary
        print(f"\n{'=' * 60}")
        print("Verifiers Evaluation Results")
        print(f"{'=' * 60}")
        print(f"  Average Score: {avg_score:.4f}")
        print(f"  Accuracy: {accuracy:.2%} ({correct}/{total})")
        print(f"  Time: {end_time - start_time:.1f}s")
        print("\n  Per-Reward Function:")
        for name, data in reward_breakdown.items():
            print(
                f"    {name}: avg={data['avg']:.4f}, correct={data['correct']}/{total}"
            )
        print(f"{'=' * 60}\n")

        # Build samples for logging
        system_prompt = self.vf_env.system_prompt or ""
        samples = []
        for i in range(min(total, 100)):  # Limit samples for logging
            prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
            completion_msgs = completions[i] if completions[i] else []

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.extend(prompt_msgs)
            if isinstance(completion_msgs, list):
                messages.extend(completion_msgs)

            samples.append(
                {
                    "messages": messages,
                    "gold_answer": answers[i] if i < len(answers) else "",
                    "score": rewards[i],
                    "correct": rewards[i] > 0,
                    "metrics": {
                        k: v[i] for k, v in per_func_metrics.items() if i < len(v)
                    },
                }
            )

        # Log results
        task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
        evaluate_log(
            metrics=metrics,
            eval_dir=getattr(self, "eval_dir", None),
            task_name=task_name,
            model_name=model_name,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                "n": 1,
            },
            samples=samples,
            verbose=getattr(self, "verbose", True),
        )

        return metrics


async def main():
    """Run verifiers evaluation with argparse CLI."""
    import os

    parser = argparse.ArgumentParser(
        description="Evaluate models using Prime Intellect's Verifiers library"
    )
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the inference server (default: http://localhost:8000/v1)",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Model name to evaluate",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key (default: uses OPENAI_API_KEY env var)",
    )
    parser.add_argument(
        "--vf-env-name",
        type=str,
        default="primeintellect/gsm8k",
        help="Verifiers environment name (default: primeintellect/gsm8k)",
    )
    parser.add_argument(
        "--env-args",
        type=str,
        default="{}",
        help="JSON string of environment-specific args (default: {})",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Generation temperature (default: 0.0)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per completion (default: 2048)",
    )
    parser.add_argument(
        "--max-eval-items",
        type=int,
        default=-1,
        help="Maximum items to evaluate, -1 for all (default: -1)",
    )
    parser.add_argument(
        "--max-concurrent",
        type=int,
        default=64,
        help="Maximum concurrent requests (default: 64)",
    )
    parser.add_argument(
        "--eval-dir",
        type=str,
        default=None,
        help="Directory to save evaluation results (default: None)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=True,
        help="Print verbose output (default: True)",
    )

    args = parser.parse_args()

    # Get API key from args or environment
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "dummy")

    # Create evaluation instance
    eval_env = VerifiersEval(
        vf_env_name=args.vf_env_name,
        env_args=args.env_args,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        max_eval_items=args.max_eval_items,
        max_concurrent=args.max_concurrent,
        eval_dir=args.eval_dir,
        verbose=args.verbose,
    )

    # Create server manager
    server_manager = ServerManager(
        configs=[
            APIServerConfig(
                api_key=api_key,
                base_url=args.server_url,
                model_name=args.model_name,
                health_check=False,
            ),
        ]
    )

    # Run evaluation
    return await eval_env(server_manager)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
environments/verifiers_server.py (330 lines, new file)

@@ -0,0 +1,330 @@
"""
Verifiers Training Environment for Atropos

Unified environment that works for both RL training (serve) and SFT data generation (process).
Uses vf_env.generate() with ManagedServer (via adapter) for automatic token and logprob tracking.

Usage:
    # RL Training (GRPO - no inference logprobs needed)
    python verifiers_server.py serve \
        --env.vf_env_name "primeintellect/alphabet-sort" \
        --openai.base_url http://localhost:9001/v1 \
        --slurm false

    # SFT Data Generation with OpenAI GPT-4o
    python verifiers_server.py process \
        --env.vf_env_name "primeintellect/alphabet-sort" \
        --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
        --env.total_steps 100 \
        --env.group_size 4 \
        --openai.model_name gpt-4o \
        --openai.base_url https://api.openai.com/v1

    # SFT Data Generation with local server
    python verifiers_server.py process \
        --env.vf_env_name "primeintellect/alphabet-sort" \
        --env.data_path_to_save_groups local_sft_data.jsonl \
        --openai.base_url http://localhost:9001/v1

To install a Verifiers/Prime environment:
    1. uv tool install prime
    2. prime login
    3. prime env install primeintellect/alphabet-sort (or any owner/environment)
Docs: https://docs.primeintellect.ai/tutorials-environments/install
"""

import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import verifiers as vf

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter

logger = logging.getLogger(__name__)


class VfEnvConfig(BaseEnvConfig):
    vf_env_name: str = ""
    env_args: str = "{}"

    def get_env_args(self) -> Dict[str, Any]:
        """Parse env_args JSON string into dict."""
        if isinstance(self.env_args, dict):
            return self.env_args
        return json.loads(self.env_args)


class VerifiersEnv(BaseEnv):
    name = "verifiers"
    env_config_cls = VfEnvConfig  # type: ignore[assignment]

    def __init__(
        self,
        config: VfEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)

        # Metrics buffers for wandb logging
        self.reward_buffer: List[float] = []
        self.metrics_buffer: Dict[str, List[float]] = defaultdict(list)
        self.num_turns_buffer: List[int] = []
        self.groups_with_identical_scores: int = 0
        self.groups_total: int = 0

        logger.info("Loading verifiers environment: %s", config.vf_env_name)
        env_args = config.get_env_args()
        if env_args:
            logger.info("Environment args: %s", env_args)
        self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
        self.rubric = self.vf_env.rubric
        self.system_prompt = self.vf_env.system_prompt

        # Get reward function names for metrics reporting
        self.reward_func_names = self.rubric._get_reward_func_names()
        logger.info("Reward functions: %s", self.reward_func_names)

        # Log multi-turn config if available
        if hasattr(self.vf_env, "max_turns"):
            logger.info("Max turns: %d", self.vf_env.max_turns)

    @classmethod
    def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
        env_config = VfEnvConfig(
            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=4,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="verifiers",
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2.5-1.5B-Instruct",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=4,
                server_type="sglang",
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Enhanced wandb logging with verifiers-specific metrics."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Log mean reward across all rollouts
        if self.reward_buffer:
            wandb_metrics["metrics/mean_reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
            wandb_metrics["metrics/reward_std"] = (
                (
                    sum(
                        (r - wandb_metrics["metrics/mean_reward"]) ** 2
                        for r in self.reward_buffer
                    )
                    / len(self.reward_buffer)
                )
                ** 0.5
                if len(self.reward_buffer) > 1
                else 0.0
            )
            self.reward_buffer = []

        # Log per-reward-function metrics (e.g., strict_accuracy, format_score)
        if self.metrics_buffer:
            for metric_name, values in self.metrics_buffer.items():
                if values:
                    avg_metric = sum(values) / len(values)
                    wandb_metrics[f"metrics/{metric_name}"] = avg_metric
            self.metrics_buffer = defaultdict(list)

        # Log multi-turn statistics
        if self.num_turns_buffer:
            wandb_metrics["metrics/avg_num_turns"] = sum(self.num_turns_buffer) / len(
                self.num_turns_buffer
            )
            wandb_metrics["metrics/max_num_turns"] = max(self.num_turns_buffer)
            self.num_turns_buffer = []

        # Log group filtering statistics (helpful for debugging)
        if self.groups_total > 0:
            wandb_metrics["metrics/groups_with_identical_scores"] = (
                self.groups_with_identical_scores
            )
            wandb_metrics["metrics/groups_total"] = self.groups_total
            wandb_metrics["metrics/identical_score_rate"] = (
                self.groups_with_identical_scores / self.groups_total
            )
            # Reset counters
            self.groups_with_identical_scores = 0
            self.groups_total = 0

        await super().wandb_log(wandb_metrics)

    async def setup(self):
        # Dataset already has: prompt, answer, info, example_id, task
        train_data = self.vf_env.get_dataset()
        self.train = train_data.to_list()
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def get_next_item(self):
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item

    async def evaluate(self) -> Dict[str, float]:
        """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
        return {}

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Unified trajectory collection using vf_env.generate() with ManagedServer.

        Works for both RL training (serve) and SFT data generation (process).
        Uses the ManagedServer adapter for automatic token and logprob tracking.
        """
        # Get server config (handle both real servers and test harness)
        if hasattr(self.server, "servers") and self.server.servers:
            server_config = self.server.servers[0].config
        else:
            # Fallback for testing
            server_config = APIServerConfig(
                model_name=self.config.tokenizer_name,
                base_url="http://localhost:8000/v1",
            )

        # Build inputs for group_size rollouts
        inputs = [
            {
                "prompt": item["prompt"],
                "answer": item.get("answer", ""),
                "example_id": item["example_id"],
                "task": item.get("task", self.config.vf_env_name),
                "info": item.get("info", {}),
            }
            for _ in range(self.config.group_size)
        ]

        # Use ManagedServer for automatic token/logprob tracking
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Create adapter that looks like AsyncOpenAI for verifiers
            adapter = ManagedServerAdapter(
                managed_server=managed,
                base_url=server_config.base_url,
            )

            # Use vf_env.generate() - handles batching and scoring internally
            results = await self.vf_env.generate(
                inputs=inputs,
                client=adapter,
                model=server_config.model_name,
                sampling_args={
                    "temperature": 1.0,
                    "max_completion_tokens": self.config.max_token_length,
                },
                max_concurrent=self.config.group_size,
                max_concurrent_scoring=self.config.group_size,
                save_results=False,
                independent_scoring=True,
            )

            # Get tracked state from ManagedServer
            managed_state = managed.get_state()
            nodes = managed_state["nodes"]

            scored_data: ScoredDataGroup = {
                "tokens": [],
                "masks": [],
                "scores": [],
                "messages": [],
                "inference_logprobs": [],
            }

            # Zip verifiers states with ManagedServer nodes for logprob tracking
            for i, vf_state in enumerate(results["state"]):
                # Extract messages from state
                messages = list(vf_state.get("prompt", [])) + list(
                    vf_state.get("completion", [])
                )
                messages = [
                    {**msg, "content": msg.get("content") or ""} for msg in messages
                ]

                # Get trajectory for metrics
                trajectory = vf_state.get("trajectory", [])

                # Get tokens, masks, and logprobs from ManagedServer
                # IMPORTANT: We use ManagedServer's tokens (not re-tokenize) to ensure
                # alignment with logprobs. ManagedServer tracks tokens and logprobs together.
                if i >= len(nodes):
                    raise RuntimeError(
                        f"Node count mismatch: expected at least {i + 1} nodes, got {len(nodes)}. "
                        "ManagedServer should track all rollouts."
                    )

                node = nodes[i]
                scored_data["tokens"].append(node.tokens)
                scored_data["masks"].append(node.masked_tokens)
                scored_data["inference_logprobs"].append(node.logprobs)
                scored_data["messages"].append(messages)

                reward = vf_state.get("reward", 0.0)
                scored_data["scores"].append(reward)

                # Metrics logging
                self.reward_buffer.append(reward)
                num_turns = len(trajectory)
                self.num_turns_buffer.append(num_turns)
                logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)

                # Per-function metrics from verifiers state
                state_metrics = vf_state.get("metrics", {})
                for metric_name, metric_value in state_metrics.items():
                    if isinstance(metric_value, (int, float)):
                        self.metrics_buffer[metric_name].append(float(metric_value))

            # Log group summary
            turns = [len(s.get("trajectory", [])) for s in results["state"]]
            logger.info(
                "Group: %d rollouts, turns=%s, rewards=%s, nodes=%d",
                len(results["state"]),
                turns,
                [f"{s:.3f}" for s in scored_data["scores"]],
                len(nodes),
            )

            # Track identical scores for debugging
            self.groups_total += 1
            if len(set(scored_data["scores"])) == 1:
                self.groups_with_identical_scores += 1
                logger.debug(
                    "Group has identical scores (%.3f) - will be filtered by base env",
                    scored_data["scores"][0],
                )

            return scored_data, []


if __name__ == "__main__":
    VerifiersEnv.cli()
@@ -17,7 +17,7 @@ dependencies = [
     "numpy",
     "wandb",
     "gymnasium",
-    "math-verify==0.7.0",
+    "math-verify>=0.8.0",
     "jinja2",
     "nltk",
     "rich",
@@ -58,6 +58,9 @@ examples = [
    "atroposlib[rewardfns]",
    "langdetect"
]
verifiers = [
    "verifiers==0.1.9.post2"
]

[build-system]
requires = ["hatchling"]