switch eval to use managed server adapter impl; moved managed server adapter
This commit is contained in:
balyan.sid@gmail.com 2026-01-23 23:25:19 +05:30
parent 32d12c05c3
commit 5a20abdce7
4 changed files with 253 additions and 270 deletions

View file

@@ -509,3 +509,63 @@ class ManagedServer:
self.sequences.clear()
else:
self.current_nodes.clear()
class ManagedServerAdapter:
    """Present a ManagedServer through the subset of the AsyncOpenAI client
    interface that external libraries commonly rely on.

    Supported surface:
      - ``client.chat.completions.create()``
      - ``client.completions.create()``
      - ``client.base_url``

    Routing requests through ManagedServer lets libraries such as verifiers
    use it transparently while still getting automatic token and logprob
    tracking.
    """

    def __init__(self, managed_server: ManagedServer, base_url: str):
        """Wrap *managed_server* and expose *base_url* for compatibility checks.

        Args:
            managed_server: The ManagedServer instance to wrap.
            base_url: The base URL to expose (for compatibility checks).
        """
        self._managed = managed_server
        self.base_url = base_url
        # Namespaces mimic the attribute layout of an AsyncOpenAI client.
        self.chat = self._ChatNamespace(self._managed)
        self.completions = self._CompletionsNamespace(self._managed)

    class _ChatNamespace:
        # Mirrors ``client.chat`` on AsyncOpenAI.
        def __init__(self, managed: ManagedServer):
            self._managed = managed
            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)

    class _ChatCompletionsNamespace:
        # Mirrors ``client.chat.completions`` on AsyncOpenAI.
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            # Delegate straight to the managed server's chat endpoint.
            return await self._managed.chat_completion(**kwargs)

    class _CompletionsNamespace:
        # Mirrors ``client.completions`` on AsyncOpenAI.
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            # Delegate straight to the managed server's completion endpoint.
            return await self._managed.completion(**kwargs)

    async def post(self, path: str, body: dict, cast_to: type):
        """Not supported - raises NotImplementedError."""
        raise NotImplementedError(
            f"ManagedServerAdapter does not support post() for path '{path}'. "
            "This is used for vLLM interleaved rollouts. Use standard chat completions."
        )

    def copy(self, **kwargs):
        """Not supported - raises NotImplementedError."""
        raise NotImplementedError(
            "ManagedServerAdapter does not support copy(). "
            "This is used for vLLM tokenization endpoints."
        )

View file

@@ -66,9 +66,9 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein
**Output (Evaluation - `verifiers_eval.py`):**
Uses `evaluate_log()` from `atroposlib.envs.eval` to output:
Uses `evaluate_log()` from `BaseEnv` to output:
- Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown
- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified)
- File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified)
**Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**
@@ -79,17 +79,19 @@ Uses `evaluate_log()` from `atroposlib.envs.eval` to output:
**CLI Options (`verifiers_eval.py`):**
Uses the standard BaseEnv CLI pattern with `evaluate` subcommand. Key options:
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server |
| `--model-name` | str | (required) | Model name to evaluate |
| `--api-key` | str | `$OPENAI_API_KEY` | API key (defaults to env var) |
| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier |
| `--temperature` | float | `0.0` | Temperature for generation |
| `--max-tokens` | int | `2048` | Maximum tokens per completion |
| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) |
| `--max-concurrent` | int | `64` | Maximum concurrent requests |
| `--eval-dir` | str | `None` | Directory to save evaluation results |
| `--openai.base_url` | str | `http://localhost:9001/v1` | URL of the inference server |
| `--openai.model_name` | str | `Qwen/Qwen2.5-1.5B-Instruct` | Model name to evaluate |
| `--openai.api_key` | str | `x` | API key |
| `--env.vf_env_name` | str | `primeintellect/gsm8k` | Prime environment identifier |
| `--env.eval_temperature` | float | `0.0` | Temperature for generation |
| `--env.eval_max_tokens` | int | `2048` | Maximum tokens per completion |
| `--env.max_eval_items` | int | `-1` | Maximum items to evaluate (-1 for all) |
| `--env.max_concurrent` | int | `64` | Maximum concurrent requests |
| `--env.data_dir_to_save_evals` | str | `None` | Directory to save evaluation results |
**Usage Examples:**
@@ -121,31 +123,33 @@ python verifiers_server.py evaluate \
--openai.base_url http://localhost:9001/v1
# Standalone Evaluation with OpenAI (verifiers_eval.py)
python eval_environments/verifiers_eval.py \
--server-url https://api.openai.com/v1 \
--model-name gpt-4o \
--vf-env-name primeintellect/gsm8k
python eval_environments/verifiers_eval.py evaluate \
--openai.base_url https://api.openai.com/v1 \
--openai.api_key $OPENAI_API_KEY \
--openai.model_name gpt-4o \
--env.vf_env_name primeintellect/gsm8k
# Quick test run with limited items
python eval_environments/verifiers_eval.py \
--server-url https://api.openai.com/v1 \
--model-name gpt-4o-mini \
--vf-env-name primeintellect/alphabet-sort \
--max-eval-items 10
python eval_environments/verifiers_eval.py evaluate \
--openai.base_url https://api.openai.com/v1 \
--openai.api_key $OPENAI_API_KEY \
--openai.model_name gpt-4o-mini \
--env.vf_env_name primeintellect/alphabet-sort \
--env.max_eval_items 10
# Evaluation with local server and results saved
python eval_environments/verifiers_eval.py \
--server-url http://localhost:9001/v1 \
--model-name Qwen/Qwen2.5-7B-Instruct \
--vf-env-name primeintellect/gsm8k \
--eval-dir ./eval_results
python eval_environments/verifiers_eval.py evaluate \
--openai.base_url http://localhost:9001/v1 \
--openai.model_name Qwen/Qwen2.5-7B-Instruct \
--env.vf_env_name primeintellect/gsm8k \
--env.data_dir_to_save_evals ./eval_results
```
**Key Implementation Details:**
- **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
- **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` pattern. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `BaseEnv` pattern with `evaluate` subcommand. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
**Prime Environment Installation:**
```bash

View file

@@ -14,31 +14,53 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install
Usage:
python verifiers_eval.py \
--server-url http://localhost:8000/v1 \
--model-name Qwen/Qwen2.5-7B-Instruct \
--vf-env-name primeintellect/gsm8k \
--max-eval-items 100
# Evaluate with local server
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 100 \
--openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
--openai.base_url "http://localhost:8000/v1"
# Evaluate with OpenAI
python verifiers_eval.py evaluate \
--env.vf_env_name "primeintellect/gsm8k" \
--env.max_eval_items 50 \
--openai.model_name "gpt-4o" \
--openai.api_key "$OPENAI_API_KEY" \
--openai.base_url "https://api.openai.com/v1"
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
import argparse
import asyncio
import json
import time
from typing import Tuple
from typing import Any, Dict, List, Tuple
import verifiers as vf
from openai import AsyncOpenAI
from atroposlib.envs.eval import EvalBase, evaluate_log
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
# Import ManagedServerAdapter from shared location
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(_timeout_seconds: int):
"""Replacement timeout decorator that doesn't use signals."""
def _no_signal_timeout(
_timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
):
"""Replacement timeout decorator that doesn't use signals.
Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
"""
# Silence unused parameter warnings - these match the original API signature
del _timeout_seconds, timeout_seconds
def decorator(func):
def wrapper(*args, **kwargs):
@@ -63,108 +85,143 @@ except ImportError:
pass # math_verify not installed
class VerifiersEval(EvalBase):
class VfEvalConfig(BaseEnvConfig):
    """Configuration for Verifiers evaluation environment."""

    # Environment selection and generation settings.
    vf_env_name: str = "primeintellect/gsm8k"
    env_args: str = "{}"  # JSON string for environment-specific args
    eval_temperature: float = 0.0
    eval_max_tokens: int = 2048
    max_eval_items: int = -1  # -1 means evaluate all items
    max_concurrent: int = 64

    # Override BaseEnvConfig defaults for eval mode.
    group_size: int = 1
    total_steps: int = 1
    steps_per_eval: int = 1
    use_wandb: bool = True

    def get_env_args(self) -> Dict[str, Any]:
        """Return env_args as a dict, decoding the JSON string if needed."""
        args = self.env_args
        return args if isinstance(args, dict) else json.loads(args)
class VerifiersEvalEnv(BaseEnv):
"""
Verifiers Evaluation using EvalBase pattern.
Verifiers Evaluation Environment using BaseEnv pattern.
Uses verifiers' native batch evaluation for efficiency,
with EvalBase's standardized logging via evaluate_log().
with BaseEnv's standardized logging via evaluate_log().
Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
"""
def __init__(
self,
vf_env_name: str = "primeintellect/gsm8k",
env_args: dict = None,
temperature: float = 0.0,
max_tokens: int = 2048,
max_eval_items: int = -1,
max_concurrent: int = 64,
eval_dir: str = None,
verbose: bool = True,
**kwargs,
):
self.vf_env_name = vf_env_name
self.env_args = env_args or {}
self.temperature = temperature
self.max_tokens = max_tokens
self.max_eval_items = max_eval_items
self.max_concurrent = max_concurrent
name = "verifiers_eval"
env_config_cls = VfEvalConfig # type: ignore[assignment]
# Load verifiers environment
self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
@classmethod
def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
"""Return default configurations."""
env_config = VfEvalConfig(
tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
vf_env_name="primeintellect/gsm8k",
wandb_name="verifiers_eval",
)
server_configs = [
APIServerConfig(
model_name="Qwen/Qwen2.5-1.5B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
),
]
return env_config, server_configs
async def setup(self):
"""Load verifiers environment and dataset."""
env_args = self.config.get_env_args()
self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
# Initialize EvalBase (calls setup_data)
super().__init__(
eval_dir=eval_dir,
verbose=verbose,
**kwargs,
)
def get_generation_params(self):
"""Generation params for logging."""
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": 1,
}
def setup_data(self) -> list:
"""Return evaluation dataset from verifiers environment."""
# Load evaluation dataset
dataset = self.vf_env.get_eval_dataset()
if self.max_eval_items > 0:
n = min(len(dataset), self.max_eval_items)
if self.config.max_eval_items > 0:
n = min(len(dataset), self.config.max_eval_items)
dataset = dataset.select(range(n))
return dataset.to_list()
self.data = dataset.to_list()
async def run_item(
self, server: ServerManager, data_item: dict # noqa: ARG002
) -> Tuple[dict, list]:
"""Not used - we override __call__ for batch evaluation."""
raise NotImplementedError(
"VerifiersEval uses batch evaluation via __call__, not per-item run_item"
async def get_next_item(self):
"""Not used in eval mode - stub implementation."""
return None
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
"""Not used in eval mode - stub implementation."""
_ = item # unused in eval mode
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
inference_logprobs=[],
advantages=[],
ref_logprobs=[],
generation_params=None,
group_overrides=None,
overrides=[],
images=[],
),
[],
)
async def __call__(self, server_manager: ServerManager):
"""Run evaluation using verifiers' native batch machinery."""
async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
"""Run evaluation using verifiers with ManagedServer."""
start_time = time.time()
# Create OpenAI client from server config
server = server_manager.servers[0]
client = AsyncOpenAI(
api_key=server.config.api_key or "x",
base_url=server.config.base_url,
timeout=getattr(server.config, "timeout", 600),
)
model = server.config.model_name
# Get server config
if hasattr(self.server, "servers") and self.server.servers:
server_config = self.server.servers[0].config
else:
server_config = self.server_configs[0]
model_name = server_config.model_name
print(f"\n{'=' * 60}")
print(f"Verifiers Evaluation: {self.vf_env_name}")
print(f"Verifiers Evaluation: {self.config.vf_env_name}")
print(f"{'=' * 60}")
print(f" Model: {model}")
print(f" Model: {model_name}")
print(f" Items: {len(self.data)}")
print(f" Reward functions: {self.reward_func_names}")
print(f" Temperature: {self.temperature}")
print(f" Max concurrent: {self.max_concurrent}")
print(f" Temperature: {self.config.eval_temperature}")
print(f" Max concurrent: {self.config.max_concurrent}")
print(f"{'=' * 60}\n")
num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
# Use verifiers' batch evaluation
results = await self.vf_env.evaluate(
client=client,
model=model,
sampling_args={
"temperature": self.temperature,
"max_tokens": self.max_tokens,
},
num_examples=num_examples,
max_concurrent=self.max_concurrent,
save_results=False,
num_examples = (
self.config.max_eval_items if self.config.max_eval_items > 0 else -1
)
# Use ManagedServer for automatic token/logprob tracking
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Create adapter that looks like AsyncOpenAI for verifiers
adapter = ManagedServerAdapter(
managed_server=managed,
base_url=server_config.base_url,
)
# Use verifiers' batch evaluation
results = await self.vf_env.evaluate(
client=adapter,
model=model_name,
sampling_args={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
},
num_examples=num_examples,
max_concurrent=self.config.max_concurrent,
save_results=False,
)
end_time = time.time()
# Extract from verifiers output
@@ -193,6 +250,11 @@ class VerifiersEval(EvalBase):
"avg_score": avg_score,
}
# Add per-function metrics
for func_name, data in reward_breakdown.items():
metrics[f"{func_name}_avg"] = data["avg"]
metrics[f"{func_name}_correct_rate"] = data["correct"] / total
# Print results summary
print(f"\n{'=' * 60}")
print("Verifiers Evaluation Results")
@@ -233,114 +295,25 @@ class VerifiersEval(EvalBase):
}
)
# Use EvalBase's evaluate_log
task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
evaluate_log(
# Use BaseEnv's evaluate_log
task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
await self.evaluate_log(
metrics=metrics,
eval_dir=getattr(self, "eval_dir", None),
task_name=task_name,
model_name=model,
model_name=model_name,
start_time=start_time,
end_time=end_time,
generation_parameters=self.get_generation_params(),
generation_parameters={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
"n": 1,
},
samples=samples,
verbose=getattr(self, "verbose", False),
verbose=True,
)
return metrics
async def main():
"""CLI entry point for verifiers evaluation."""
import os
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
parser = argparse.ArgumentParser(
description="Evaluate models using Verifiers environments"
)
# Server args (same as eval_runner)
parser.add_argument(
"--server-url",
type=str,
default="http://localhost:8000/v1",
help="URL of the inference server",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name to evaluate",
)
parser.add_argument(
"--api-key",
type=str,
default=os.getenv("OPENAI_API_KEY", "x"),
help="API key (defaults to OPENAI_API_KEY env var)",
)
# Verifiers-specific args
parser.add_argument(
"--vf-env-name",
type=str,
default="primeintellect/gsm8k",
help="Verifiers environment name (e.g., primeintellect/gsm8k)",
)
parser.add_argument(
"--max-eval-items",
type=int,
default=-1,
help="Maximum items to evaluate (-1 for all)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Generation temperature",
)
parser.add_argument(
"--max-tokens",
type=int,
default=2048,
help="Maximum tokens per completion",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=64,
help="Maximum concurrent requests",
)
parser.add_argument(
"--eval-dir",
type=str,
default=None,
help="Directory to save evaluation results",
)
args = parser.parse_args()
# Create server manager
server_manager = ServerManager(
configs=[
APIServerConfig(
api_key=args.api_key,
base_url=args.server_url,
model_name=args.model_name,
health_check=False,
),
]
)
# Create and run evaluation
eval_instance = VerifiersEval(
vf_env_name=args.vf_env_name,
max_eval_items=args.max_eval_items,
temperature=args.temperature,
max_tokens=args.max_tokens,
max_concurrent=args.max_concurrent,
eval_dir=args.eval_dir,
verbose=True,
)
return await eval_instance(server_manager)
if __name__ == "__main__":
asyncio.run(main())
VerifiersEvalEnv.cli()

View file

@@ -46,65 +46,11 @@ from atroposlib.envs.base import (
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
logger = logging.getLogger(__name__)
class ManagedServerAdapter:
    """Make a ManagedServer usable wherever verifiers expects AsyncOpenAI.

    Only the slice of the AsyncOpenAI interface that verifiers touches is
    implemented:
      - ``client.chat.completions.create()``
      - ``client.completions.create()``
      - ``client.base_url``
    """

    def __init__(self, managed_server: ManagedServer, base_url: str):
        # Hold the wrapped server and surface the URL that compatibility
        # checks expect to find on an AsyncOpenAI client.
        self._managed = managed_server
        self.base_url = base_url
        self.chat = self._ChatNamespace(self._managed)
        self.completions = self._CompletionsNamespace(self._managed)

    class _ChatNamespace:
        # Stand-in for ``client.chat``.
        def __init__(self, managed: ManagedServer):
            self._managed = managed
            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)

    class _ChatCompletionsNamespace:
        # Stand-in for ``client.chat.completions``.
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            # Log entry/exit around the delegated call for traceability.
            logger.info(
                "ManagedServerAdapter.chat.completions.create called with model=%s",
                kwargs.get("model"),
            )
            result = await self._managed.chat_completion(**kwargs)
            logger.info("ManagedServerAdapter.chat.completions.create completed")
            return result

    class _CompletionsNamespace:
        # Stand-in for ``client.completions``.
        def __init__(self, managed: ManagedServer):
            self._managed = managed

        async def create(self, **kwargs):
            return await self._managed.completion(**kwargs)

    async def post(self, path: str, body: dict, cast_to: type):
        # Deliberately unsupported; surface a clear error instead of silently
        # misbehaving when vLLM-specific endpoints are requested.
        raise NotImplementedError(
            f"ManagedServerAdapter does not support post() for path '{path}'. "
            "This is used for vLLM interleaved rollouts. Use standard chat completions."
        )

    def copy(self, **kwargs):
        # Deliberately unsupported; used only by vLLM tokenization endpoints.
        raise NotImplementedError(
            "ManagedServerAdapter does not support copy(). "
            "This is used for vLLM tokenization endpoints."
        )
class VfEnvConfig(BaseEnvConfig):
vf_env_name: str = ""
env_args: str = "{}"