Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

switch eval to use managed server adapter impl. moved managed server adapter

This commit is contained in:
parent 32d12c05c3
commit 5a20abdce7

4 changed files with 253 additions and 270 deletions
@@ -509,3 +509,63 @@ class ManagedServer:
             self.sequences.clear()
         else:
             self.current_nodes.clear()
+
+
+class ManagedServerAdapter:
+    """
+    Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.
+
+    Implements the subset of AsyncOpenAI interface commonly used:
+    - client.chat.completions.create()
+    - client.completions.create()
+    - client.base_url
+
+    This allows libraries like verifiers to use ManagedServer transparently
+    while still getting automatic token and logprob tracking.
+    """
+
+    def __init__(self, managed_server: ManagedServer, base_url: str):
+        """
+        Initialize the adapter.
+
+        Args:
+            managed_server: The ManagedServer instance to wrap
+            base_url: The base URL to expose (for compatibility checks)
+        """
+        self._managed = managed_server
+        self.base_url = base_url
+        self.chat = self._ChatNamespace(self._managed)
+        self.completions = self._CompletionsNamespace(self._managed)
+
+    class _ChatNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
+
+    class _ChatCompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.chat_completion(**kwargs)
+
+    class _CompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.completion(**kwargs)
+
+    async def post(self, path: str, body: dict, cast_to: type):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            f"ManagedServerAdapter does not support post() for path '{path}'. "
+            "This is used for vLLM interleaved rollouts. Use standard chat completions."
+        )
+
+    def copy(self, **kwargs):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            "ManagedServerAdapter does not support copy(). "
+            "This is used for vLLM tokenization endpoints."
+        )
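The adapter above exposes only the AsyncOpenAI-shaped surface that external libraries touch. A minimal sketch of how it might be used (illustrative only: `server_manager`, `tokenizer`, and `model_name` are placeholders for objects created elsewhere; the `managed_server()` context-manager entry point is the one shown later in this diff):

```python
# Illustrative sketch only - `server_manager`, `tokenizer`, and `model_name`
# are placeholders standing in for objects created elsewhere.
from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter


async def run_with_adapter(server_manager, tokenizer, model_name: str):
    # ManagedServer tracks tokens/logprobs; the adapter exposes the small
    # AsyncOpenAI-shaped surface (chat.completions.create, completions.create,
    # base_url) that libraries such as verifiers call.
    async with server_manager.managed_server(tokenizer=tokenizer) as managed:
        adapter = ManagedServerAdapter(
            managed_server=managed,
            base_url="http://localhost:9001/v1",  # assumed local inference URL
        )
        # Any code written against AsyncOpenAI's chat API can now use `adapter`.
        response = await adapter.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": "2 + 2 = ?"}],
        )
        return response
```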
@@ -66,9 +66,9 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein
 **Output (Evaluation - `verifiers_eval.py`):**
 
-Uses `evaluate_log()` from `atroposlib.envs.eval` to output:
+Uses `evaluate_log()` from `BaseEnv` to output:
 - Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown
-- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified)
+- File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified)
 
 **Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**
 
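For reference, the metrics written to `metrics.json` might look roughly like this (the values are invented; the per-reward-function keys follow the `<func>_avg` / `<func>_correct_rate` pattern added later in this diff, and `correct_answer_reward_func` is a placeholder name):

```python
# Hypothetical example of the metrics dict - field values are made up, and the
# per-reward keys depend on the reward functions reported by the environment's rubric.
example_metrics = {
    "accuracy": 0.78,
    "avg_score": 0.81,
    # one pair of entries per reward function in the rubric
    "correct_answer_reward_func_avg": 0.81,
    "correct_answer_reward_func_correct_rate": 0.78,
}
```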
@@ -79,17 +79,19 @@ Uses `evaluate_log()` from `atroposlib.envs.eval` to output:
 **CLI Options (`verifiers_eval.py`):**
 
+Uses the standard BaseEnv CLI pattern with `evaluate` subcommand. Key options:
+
 | Option | Type | Default | Description |
 |--------|------|---------|-------------|
-| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server |
-| `--model-name` | str | (required) | Model name to evaluate |
-| `--api-key` | str | `$OPENAI_API_KEY` | API key (defaults to env var) |
-| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier |
-| `--temperature` | float | `0.0` | Temperature for generation |
-| `--max-tokens` | int | `2048` | Maximum tokens per completion |
-| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) |
-| `--max-concurrent` | int | `64` | Maximum concurrent requests |
-| `--eval-dir` | str | `None` | Directory to save evaluation results |
+| `--openai.base_url` | str | `http://localhost:9001/v1` | URL of the inference server |
+| `--openai.model_name` | str | `Qwen/Qwen2.5-1.5B-Instruct` | Model name to evaluate |
+| `--openai.api_key` | str | `x` | API key |
+| `--env.vf_env_name` | str | `primeintellect/gsm8k` | Prime environment identifier |
+| `--env.eval_temperature` | float | `0.0` | Temperature for generation |
+| `--env.eval_max_tokens` | int | `2048` | Maximum tokens per completion |
+| `--env.max_eval_items` | int | `-1` | Maximum items to evaluate (-1 for all) |
+| `--env.max_concurrent` | int | `64` | Maximum concurrent requests |
+| `--env.data_dir_to_save_evals` | str | `None` | Directory to save evaluation results |
 
 **Usage Examples:**
 
@@ -121,31 +123,33 @@ python verifiers_server.py evaluate \
     --openai.base_url http://localhost:9001/v1
 
 # Standalone Evaluation with OpenAI (verifiers_eval.py)
-python eval_environments/verifiers_eval.py \
-    --server-url https://api.openai.com/v1 \
-    --model-name gpt-4o \
-    --vf-env-name primeintellect/gsm8k
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url https://api.openai.com/v1 \
+    --openai.api_key $OPENAI_API_KEY \
+    --openai.model_name gpt-4o \
+    --env.vf_env_name primeintellect/gsm8k
 
 # Quick test run with limited items
-python eval_environments/verifiers_eval.py \
-    --server-url https://api.openai.com/v1 \
-    --model-name gpt-4o-mini \
-    --vf-env-name primeintellect/alphabet-sort \
-    --max-eval-items 10
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url https://api.openai.com/v1 \
+    --openai.api_key $OPENAI_API_KEY \
+    --openai.model_name gpt-4o-mini \
+    --env.vf_env_name primeintellect/alphabet-sort \
+    --env.max_eval_items 10
 
 # Evaluation with local server and results saved
-python eval_environments/verifiers_eval.py \
-    --server-url http://localhost:9001/v1 \
-    --model-name Qwen/Qwen2.5-7B-Instruct \
-    --vf-env-name primeintellect/gsm8k \
-    --eval-dir ./eval_results
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url http://localhost:9001/v1 \
+    --openai.model_name Qwen/Qwen2.5-7B-Instruct \
+    --env.vf_env_name primeintellect/gsm8k \
+    --env.data_dir_to_save_evals ./eval_results
 ```
 
 **Key Implementation Details:**
 
 - **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
 - **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
-- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` pattern. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
+- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `BaseEnv` pattern with `evaluate` subcommand. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
 
 **Prime Environment Installation:**
 ```bash
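A rough sketch of the `ScoredDataGroup` payload mentioned in the RL training bullet above, shown as a plain dict (every value is invented; the field names come from the stub that appears later in this diff):

```python
# Illustrative only - every value below is invented; in practice ManagedServer
# fills these from real rollouts so tokens, masks, and logprobs stay aligned.
example_group = {
    "tokens": [[151644, 872, 198, 17, 10, 17, 28, 19]],  # one token-ID list per rollout
    "masks": [[-100, -100, -100, 17, 10, 17, 28, 19]],   # prompt positions masked out (assumed -100 convention)
    "scores": [1.0],                                      # one scalar reward per rollout
    "inference_logprobs": [[-0.2, -0.1, -0.4, -0.3, -0.1, -0.2, -0.1, -0.3]],  # per-token logprobs from the server
}
```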
@@ -14,31 +14,53 @@ To install a Verifiers/Prime environment:
 Docs: https://docs.primeintellect.ai/tutorials-environments/install
 
 Usage:
-    python verifiers_eval.py \
-        --server-url http://localhost:8000/v1 \
-        --model-name Qwen/Qwen2.5-7B-Instruct \
-        --vf-env-name primeintellect/gsm8k \
-        --max-eval-items 100
+    # Evaluate with local server
+    python verifiers_eval.py evaluate \
+        --env.vf_env_name "primeintellect/gsm8k" \
+        --env.max_eval_items 100 \
+        --openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
+        --openai.base_url "http://localhost:8000/v1"
+
+    # Evaluate with OpenAI
+    python verifiers_eval.py evaluate \
+        --env.vf_env_name "primeintellect/gsm8k" \
+        --env.max_eval_items 50 \
+        --openai.model_name "gpt-4o" \
+        --openai.api_key "$OPENAI_API_KEY" \
+        --openai.base_url "https://api.openai.com/v1"
 
 Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
 """
 
-import argparse
-import asyncio
 import json
 import time
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import verifiers as vf
-from openai import AsyncOpenAI
 
-from atroposlib.envs.eval import EvalBase, evaluate_log
-from atroposlib.envs.server_handling.server_manager import ServerManager
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+
+# Import ManagedServerAdapter from shared location
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
 
 # Patch math_verify timeout to work in async context
 # The signal-based timeout doesn't work in non-main threads (asyncio event loop)
-def _no_signal_timeout(_timeout_seconds: int):
-    """Replacement timeout decorator that doesn't use signals."""
+def _no_signal_timeout(
+    _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
+):
+    """Replacement timeout decorator that doesn't use signals.
+
+    Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
+    """
+    # Silence unused parameter warnings - these match the original API signature
+    del _timeout_seconds, timeout_seconds
 
     def decorator(func):
         def wrapper(*args, **kwargs):
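The hunk is cut off inside `wrapper`. For orientation, a signal-free replacement for a signal-based timeout decorator generally looks like the sketch below; everything past the cut is an assumption about the pattern, not the file's actual code:

```python
# Sketch of the general pattern only - the real wrapper body is not shown in this diff.
def _no_signal_timeout_sketch(_timeout_seconds=None, *, timeout_seconds=None):
    """Drop the signal-based timeout entirely and just call the function."""
    del _timeout_seconds, timeout_seconds  # keep the original decorator's signature

    def decorator(func):
        def wrapper(*args, **kwargs):
            # No SIGALRM here, so this is safe off the main thread
            # (e.g. inside an asyncio event-loop worker).
            return func(*args, **kwargs)

        return wrapper

    return decorator
```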
@@ -63,108 +85,143 @@ except ImportError:
     pass  # math_verify not installed


-class VerifiersEval(EvalBase):
+class VfEvalConfig(BaseEnvConfig):
+    """Configuration for Verifiers evaluation environment."""
+
+    vf_env_name: str = "primeintellect/gsm8k"
+    env_args: str = "{}"  # JSON string for environment-specific args
+    eval_temperature: float = 0.0
+    eval_max_tokens: int = 2048
+    max_eval_items: int = -1  # -1 means evaluate all items
+    max_concurrent: int = 64
+
+    # Override BaseEnvConfig defaults for eval mode
+    group_size: int = 1
+    total_steps: int = 1
+    steps_per_eval: int = 1
+    use_wandb: bool = True
+
+    def get_env_args(self) -> Dict[str, Any]:
+        """Parse env_args JSON string into dict."""
+        if isinstance(self.env_args, dict):
+            return self.env_args
+        return json.loads(self.env_args)
+
+
+class VerifiersEvalEnv(BaseEnv):
     """
-    Verifiers Evaluation using EvalBase pattern.
+    Verifiers Evaluation Environment using BaseEnv pattern.

     Uses verifiers' native batch evaluation for efficiency,
-    with EvalBase's standardized logging via evaluate_log().
+    with BaseEnv's standardized logging via evaluate_log().

     Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
     """

-    def __init__(
-        self,
-        vf_env_name: str = "primeintellect/gsm8k",
-        env_args: dict = None,
-        temperature: float = 0.0,
-        max_tokens: int = 2048,
-        max_eval_items: int = -1,
-        max_concurrent: int = 64,
-        eval_dir: str = None,
-        verbose: bool = True,
-        **kwargs,
-    ):
-        self.vf_env_name = vf_env_name
-        self.env_args = env_args or {}
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.max_eval_items = max_eval_items
-        self.max_concurrent = max_concurrent
+    name = "verifiers_eval"
+    env_config_cls = VfEvalConfig  # type: ignore[assignment]

-        # Load verifiers environment
-        self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
+    @classmethod
+    def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
+        """Return default configurations."""
+        env_config = VfEvalConfig(
+            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
+            vf_env_name="primeintellect/gsm8k",
+            wandb_name="verifiers_eval",
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+            ),
+        ]
+        return env_config, server_configs
+
+    async def setup(self):
+        """Load verifiers environment and dataset."""
+        env_args = self.config.get_env_args()
+        self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
         self.reward_func_names = self.vf_env.rubric._get_reward_func_names()

-        # Initialize EvalBase (calls setup_data)
-        super().__init__(
-            eval_dir=eval_dir,
-            verbose=verbose,
-            **kwargs,
-        )
-
-    def get_generation_params(self):
-        """Generation params for logging."""
-        return {
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "n": 1,
-        }
-
-    def setup_data(self) -> list:
-        """Return evaluation dataset from verifiers environment."""
         # Load evaluation dataset
         dataset = self.vf_env.get_eval_dataset()
-        if self.max_eval_items > 0:
-            n = min(len(dataset), self.max_eval_items)
+        if self.config.max_eval_items > 0:
+            n = min(len(dataset), self.config.max_eval_items)
             dataset = dataset.select(range(n))
-        return dataset.to_list()
+        self.data = dataset.to_list()

-    async def run_item(
-        self, server: ServerManager, data_item: dict  # noqa: ARG002
-    ) -> Tuple[dict, list]:
-        """Not used - we override __call__ for batch evaluation."""
-        raise NotImplementedError(
-            "VerifiersEval uses batch evaluation via __call__, not per-item run_item"
+    async def get_next_item(self):
+        """Not used in eval mode - stub implementation."""
+        return None
+
+    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
+        """Not used in eval mode - stub implementation."""
+        _ = item  # unused in eval mode
+        return (
+            ScoredDataGroup(
+                tokens=[],
+                masks=[],
+                scores=[],
+                messages=[],
+                inference_logprobs=[],
+                advantages=[],
+                ref_logprobs=[],
+                generation_params=None,
+                group_overrides=None,
+                overrides=[],
+                images=[],
+            ),
+            [],
         )

-    async def __call__(self, server_manager: ServerManager):
-        """Run evaluation using verifiers' native batch machinery."""
+    async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
+        """Run evaluation using verifiers with ManagedServer."""
         start_time = time.time()

-        # Create OpenAI client from server config
-        server = server_manager.servers[0]
-        client = AsyncOpenAI(
-            api_key=server.config.api_key or "x",
-            base_url=server.config.base_url,
-            timeout=getattr(server.config, "timeout", 600),
-        )
-        model = server.config.model_name
+        # Get server config
+        if hasattr(self.server, "servers") and self.server.servers:
+            server_config = self.server.servers[0].config
+        else:
+            server_config = self.server_configs[0]
+
+        model_name = server_config.model_name

         print(f"\n{'=' * 60}")
-        print(f"Verifiers Evaluation: {self.vf_env_name}")
+        print(f"Verifiers Evaluation: {self.config.vf_env_name}")
         print(f"{'=' * 60}")
-        print(f"  Model: {model}")
+        print(f"  Model: {model_name}")
         print(f"  Items: {len(self.data)}")
         print(f"  Reward functions: {self.reward_func_names}")
-        print(f"  Temperature: {self.temperature}")
-        print(f"  Max concurrent: {self.max_concurrent}")
+        print(f"  Temperature: {self.config.eval_temperature}")
+        print(f"  Max concurrent: {self.config.max_concurrent}")
         print(f"{'=' * 60}\n")

-        num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
-
-        # Use verifiers' batch evaluation
-        results = await self.vf_env.evaluate(
-            client=client,
-            model=model,
-            sampling_args={
-                "temperature": self.temperature,
-                "max_tokens": self.max_tokens,
-            },
-            num_examples=num_examples,
-            max_concurrent=self.max_concurrent,
-            save_results=False,
+        num_examples = (
+            self.config.max_eval_items if self.config.max_eval_items > 0 else -1
         )

+        # Use ManagedServer for automatic token/logprob tracking
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Create adapter that looks like AsyncOpenAI for verifiers
+            adapter = ManagedServerAdapter(
+                managed_server=managed,
+                base_url=server_config.base_url,
+            )
+
+            # Use verifiers' batch evaluation
+            results = await self.vf_env.evaluate(
+                client=adapter,
+                model=model_name,
+                sampling_args={
+                    "temperature": self.config.eval_temperature,
+                    "max_tokens": self.config.eval_max_tokens,
+                },
+                num_examples=num_examples,
+                max_concurrent=self.config.max_concurrent,
+                save_results=False,
+            )

         end_time = time.time()

         # Extract from verifiers output
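A small sketch of how the `env_args` field of `VfEvalConfig` above might be exercised; the JSON round-trip mirrors `get_env_args()`, but the `"use_think"` key is invented for illustration. On the CLI this would presumably be passed as `--env.env_args '{"use_think": false}'`, following the `--env.*` prefix used by the other options.

```python
# Hypothetical example - the keys inside env_args depend entirely on the
# Prime/Verifiers environment being loaded; "use_think" is a made-up name.
from eval_environments.verifiers_eval import VfEvalConfig  # assumed import path

config = VfEvalConfig(
    tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
    vf_env_name="primeintellect/gsm8k",
    env_args='{"use_think": false}',  # JSON string, parsed by get_env_args()
)
assert config.get_env_args() == {"use_think": False}
```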
@@ -193,6 +250,11 @@ class VerifiersEval(EvalBase):
             "avg_score": avg_score,
         }

+        # Add per-function metrics
+        for func_name, data in reward_breakdown.items():
+            metrics[f"{func_name}_avg"] = data["avg"]
+            metrics[f"{func_name}_correct_rate"] = data["correct"] / total
+
         # Print results summary
         print(f"\n{'=' * 60}")
         print("Verifiers Evaluation Results")
@@ -233,114 +295,25 @@ class VerifiersEval(EvalBase):
                 }
             )

-        # Use EvalBase's evaluate_log
-        task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
-        evaluate_log(
+        # Use BaseEnv's evaluate_log
+        task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
+        await self.evaluate_log(
             metrics=metrics,
-            eval_dir=getattr(self, "eval_dir", None),
             task_name=task_name,
-            model_name=model,
+            model_name=model_name,
             start_time=start_time,
             end_time=end_time,
-            generation_parameters=self.get_generation_params(),
+            generation_parameters={
+                "temperature": self.config.eval_temperature,
+                "max_tokens": self.config.eval_max_tokens,
+                "n": 1,
+            },
             samples=samples,
-            verbose=getattr(self, "verbose", False),
+            verbose=True,
         )

         return metrics


-async def main():
-    """CLI entry point for verifiers evaluation."""
-    import os
-
-    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
-
-    parser = argparse.ArgumentParser(
-        description="Evaluate models using Verifiers environments"
-    )
-    # Server args (same as eval_runner)
-    parser.add_argument(
-        "--server-url",
-        type=str,
-        default="http://localhost:8000/v1",
-        help="URL of the inference server",
-    )
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        required=True,
-        help="Model name to evaluate",
-    )
-    parser.add_argument(
-        "--api-key",
-        type=str,
-        default=os.getenv("OPENAI_API_KEY", "x"),
-        help="API key (defaults to OPENAI_API_KEY env var)",
-    )
-    # Verifiers-specific args
-    parser.add_argument(
-        "--vf-env-name",
-        type=str,
-        default="primeintellect/gsm8k",
-        help="Verifiers environment name (e.g., primeintellect/gsm8k)",
-    )
-    parser.add_argument(
-        "--max-eval-items",
-        type=int,
-        default=-1,
-        help="Maximum items to evaluate (-1 for all)",
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.0,
-        help="Generation temperature",
-    )
-    parser.add_argument(
-        "--max-tokens",
-        type=int,
-        default=2048,
-        help="Maximum tokens per completion",
-    )
-    parser.add_argument(
-        "--max-concurrent",
-        type=int,
-        default=64,
-        help="Maximum concurrent requests",
-    )
-    parser.add_argument(
-        "--eval-dir",
-        type=str,
-        default=None,
-        help="Directory to save evaluation results",
-    )
-    args = parser.parse_args()
-
-    # Create server manager
-    server_manager = ServerManager(
-        configs=[
-            APIServerConfig(
-                api_key=args.api_key,
-                base_url=args.server_url,
-                model_name=args.model_name,
-                health_check=False,
-            ),
-        ]
-    )
-
-    # Create and run evaluation
-    eval_instance = VerifiersEval(
-        vf_env_name=args.vf_env_name,
-        max_eval_items=args.max_eval_items,
-        temperature=args.temperature,
-        max_tokens=args.max_tokens,
-        max_concurrent=args.max_concurrent,
-        eval_dir=args.eval_dir,
-        verbose=True,
-    )
-    return await eval_instance(server_manager)
-
-
 if __name__ == "__main__":
-    asyncio.run(main())
+    VerifiersEvalEnv.cli()
@@ -46,65 +46,11 @@ from atroposlib.envs.base import (
     BaseEnvConfig,
     ScoredDataGroup,
 )
-from atroposlib.envs.server_handling.managed_server import ManagedServer
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter

 logger = logging.getLogger(__name__)


-class ManagedServerAdapter:
-    """
-    Adapter that makes ManagedServer look like AsyncOpenAI for verifiers.
-
-    Implements the subset of AsyncOpenAI interface that verifiers uses:
-    - client.chat.completions.create()
-    - client.completions.create()
-    - client.base_url
-    """
-
-    def __init__(self, managed_server: ManagedServer, base_url: str):
-        self._managed = managed_server
-        self.base_url = base_url
-        self.chat = self._ChatNamespace(self._managed)
-        self.completions = self._CompletionsNamespace(self._managed)
-
-    class _ChatNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
-
-    class _ChatCompletionsNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-
-        async def create(self, **kwargs):
-            logger.info(
-                "ManagedServerAdapter.chat.completions.create called with model=%s",
-                kwargs.get("model"),
-            )
-            result = await self._managed.chat_completion(**kwargs)
-            logger.info("ManagedServerAdapter.chat.completions.create completed")
-            return result
-
-    class _CompletionsNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-
-        async def create(self, **kwargs):
-            return await self._managed.completion(**kwargs)
-
-    async def post(self, path: str, body: dict, cast_to: type):
-        raise NotImplementedError(
-            f"ManagedServerAdapter does not support post() for path '{path}'. "
-            "This is used for vLLM interleaved rollouts. Use standard chat completions."
-        )
-
-    def copy(self, **kwargs):
-        raise NotImplementedError(
-            "ManagedServerAdapter does not support copy(). "
-            "This is used for vLLM tokenization endpoints."
-        )
-
-
 class VfEnvConfig(BaseEnvConfig):
     vf_env_name: str = ""
     env_args: str = "{}"