diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index cbc97268..7b26b5c1 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -509,3 +509,63 @@ class ManagedServer:
             self.sequences.clear()
         else:
             self.current_nodes.clear()
+
+
+class ManagedServerAdapter:
+    """
+    Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.
+
+    Implements the subset of the AsyncOpenAI interface that is commonly used:
+    - client.chat.completions.create()
+    - client.completions.create()
+    - client.base_url
+
+    This allows libraries like verifiers to use ManagedServer transparently
+    while still getting automatic token and logprob tracking.
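+
+    Example (an illustrative sketch only; the surrounding server/tokenizer setup
+    shown here is assumed and may differ in practice):
+
+        async with server.managed_server(tokenizer=tokenizer) as managed:
+            client = ManagedServerAdapter(managed, base_url="http://localhost:9001/v1")
+            response = await client.chat.completions.create(
+                model="my-model",  # hypothetical model name
+                messages=[{"role": "user", "content": "Hello"}],
+            )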
+    """
+
+    def __init__(self, managed_server: ManagedServer, base_url: str):
+        """
+        Initialize the adapter.
+
+        Args:
+            managed_server: The ManagedServer instance to wrap
+            base_url: The base URL to expose (for compatibility checks)
+        """
+        self._managed = managed_server
+        self.base_url = base_url
+        self.chat = self._ChatNamespace(self._managed)
+        self.completions = self._CompletionsNamespace(self._managed)
+
+    class _ChatNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
+
+    class _ChatCompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.chat_completion(**kwargs)
+
+    class _CompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.completion(**kwargs)
+
+    async def post(self, path: str, body: dict, cast_to: type):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            f"ManagedServerAdapter does not support post() for path '{path}'. "
+            "This is used for vLLM interleaved rollouts. Use standard chat completions."
+        )
+
+    def copy(self, **kwargs):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            "ManagedServerAdapter does not support copy(). "
+            "This is used for vLLM tokenization endpoints."
+        )
diff --git a/environments/README.md b/environments/README.md
index 5473460b..67bcb6f6 100644
--- a/environments/README.md
+++ b/environments/README.md
@@ -66,9 +66,9 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein
 
 **Output (Evaluation - `verifiers_eval.py`):**
 
-Uses `evaluate_log()` from `atroposlib.envs.eval` to output:
+Uses `evaluate_log()` from `BaseEnv` to output:
 - Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown
-- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified)
+- File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified)
 
 **Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**
 
@@ -79,17 +79,19 @@
 
 **CLI Options (`verifiers_eval.py`):**
 
+Uses the standard BaseEnv CLI pattern with the `evaluate` subcommand. Key options:
+
 | Option | Type | Default | Description |
 |--------|------|---------|-------------|
-| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server |
-| `--model-name` | str | (required) | Model name to evaluate |
-| `--api-key` | str | `$OPENAI_API_KEY` | API key (defaults to env var) |
-| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier |
-| `--temperature` | float | `0.0` | Temperature for generation |
-| `--max-tokens` | int | `2048` | Maximum tokens per completion |
-| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) |
-| `--max-concurrent` | int | `64` | Maximum concurrent requests |
-| `--eval-dir` | str | `None` | Directory to save evaluation results |
+| `--openai.base_url` | str | `http://localhost:9001/v1` | URL of the inference server |
+| `--openai.model_name` | str | `Qwen/Qwen2.5-1.5B-Instruct` | Model name to evaluate |
+| `--openai.api_key` | str | `x` | API key |
+| `--env.vf_env_name` | str | `primeintellect/gsm8k` | Prime environment identifier |
+| `--env.eval_temperature` | float | `0.0` | Temperature for generation |
+| `--env.eval_max_tokens` | int | `2048` | Maximum tokens per completion |
+| `--env.max_eval_items` | int | `-1` | Maximum items to evaluate (-1 for all) |
+| `--env.max_concurrent` | int | `64` | Maximum concurrent requests |
+| `--env.data_dir_to_save_evals` | str | `None` | Directory to save evaluation results |
 
 **Usage Examples:**
 
@@ -121,31 +123,33 @@
 python verifiers_server.py evaluate \
     --openai.base_url http://localhost:9001/v1
 
 # Standalone Evaluation with OpenAI (verifiers_eval.py)
-python eval_environments/verifiers_eval.py \
-    --server-url https://api.openai.com/v1 \
-    --model-name gpt-4o \
-    --vf-env-name primeintellect/gsm8k
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url https://api.openai.com/v1 \
+    --openai.api_key $OPENAI_API_KEY \
+    --openai.model_name gpt-4o \
+    --env.vf_env_name primeintellect/gsm8k
 
 # Quick test run with limited items
-python eval_environments/verifiers_eval.py \
-    --server-url https://api.openai.com/v1 \
-    --model-name gpt-4o-mini \
-    --vf-env-name primeintellect/alphabet-sort \
-    --max-eval-items 10
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url https://api.openai.com/v1 \
+    --openai.api_key $OPENAI_API_KEY \
+    --openai.model_name gpt-4o-mini \
+    --env.vf_env_name primeintellect/alphabet-sort \
+    --env.max_eval_items 10
 
 # Evaluation with local server and results saved
-python eval_environments/verifiers_eval.py \
-    --server-url http://localhost:9001/v1 \
-    --model-name Qwen/Qwen2.5-7B-Instruct \
-    --vf-env-name primeintellect/gsm8k \
-    --eval-dir ./eval_results
+python eval_environments/verifiers_eval.py evaluate \
+    --openai.base_url http://localhost:9001/v1 \
+    --openai.model_name Qwen/Qwen2.5-7B-Instruct \
+    --env.vf_env_name primeintellect/gsm8k \
+    --env.data_dir_to_save_evals ./eval_results
 ```
 
 **Key Implementation Details:**
 
 - **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
 - **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
-- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` pattern. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
+- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using the `BaseEnv` pattern with the `evaluate` subcommand (see the sketch below). Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API.
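+
+The evaluation path is wired together roughly as follows (a simplified, illustrative sketch of the code added in this PR, not a runnable excerpt; it assumes the `ManagedServerAdapter` and `managed_server()` APIs introduced above and omits error handling):
+
+```python
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+    adapter = ManagedServerAdapter(managed_server=managed, base_url=server_config.base_url)
+    results = await self.vf_env.evaluate(
+        client=adapter,  # quacks like AsyncOpenAI, so verifiers can drive it directly
+        model=model_name,
+        sampling_args={"temperature": 0.0, "max_tokens": 2048},
+        max_concurrent=64,
+    )
+```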
 
 **Prime Environment Installation:**
 ```bash
diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py
index dd21e492..3100cf04 100644
--- a/environments/eval_environments/verifiers_eval.py
+++ b/environments/eval_environments/verifiers_eval.py
@@ -14,31 +14,53 @@
 To install a Verifiers/Prime environment:
 
 Docs: https://docs.primeintellect.ai/tutorials-environments/install
 
 Usage:
-    python verifiers_eval.py \
-        --server-url http://localhost:8000/v1 \
-        --model-name Qwen/Qwen2.5-7B-Instruct \
-        --vf-env-name primeintellect/gsm8k \
-        --max-eval-items 100
+    # Evaluate with local server
+    python verifiers_eval.py evaluate \
+        --env.vf_env_name "primeintellect/gsm8k" \
+        --env.max_eval_items 100 \
+        --openai.model_name "Qwen/Qwen2.5-7B-Instruct" \
+        --openai.base_url "http://localhost:8000/v1"
+
+    # Evaluate with OpenAI
+    python verifiers_eval.py evaluate \
+        --env.vf_env_name "primeintellect/gsm8k" \
+        --env.max_eval_items 50 \
+        --openai.model_name "gpt-4o" \
+        --openai.api_key "$OPENAI_API_KEY" \
+        --openai.base_url "https://api.openai.com/v1"
 
 Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
 """
 
-import argparse
-import asyncio
+import json
 import time
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import verifiers as vf
-from openai import AsyncOpenAI
 
-from atroposlib.envs.eval import EvalBase, evaluate_log
-from atroposlib.envs.server_handling.server_manager import ServerManager
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+
+# Import ManagedServerAdapter from shared location
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
 
 # Patch math_verify timeout to work in async context
 # The signal-based timeout doesn't work in non-main threads (asyncio event loop)
-def _no_signal_timeout(_timeout_seconds: int):
-    """Replacement timeout decorator that doesn't use signals."""
+
+
+def _no_signal_timeout(
+    _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
+):
+    """Replacement timeout decorator that doesn't use signals.
+
+    Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
+    """
+    # Silence unused parameter warnings - these match the original API signature
+    del _timeout_seconds, timeout_seconds
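+    # Net effect (illustrative): once patched in, the math_verify timeout decorator
+    # becomes a plain pass-through wrapper, e.g. _no_signal_timeout(5)(fn)(x) simply
+    # runs fn(x) with no signal-based time limit, which is what fails off the main thread.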
+ """ + # Silence unused parameter warnings - these match the original API signature + del _timeout_seconds, timeout_seconds def decorator(func): def wrapper(*args, **kwargs): @@ -63,108 +85,143 @@ except ImportError: pass # math_verify not installed -class VerifiersEval(EvalBase): +class VfEvalConfig(BaseEnvConfig): + """Configuration for Verifiers evaluation environment.""" + + vf_env_name: str = "primeintellect/gsm8k" + env_args: str = "{}" # JSON string for environment-specific args + eval_temperature: float = 0.0 + eval_max_tokens: int = 2048 + max_eval_items: int = -1 # -1 means evaluate all items + max_concurrent: int = 64 + + # Override BaseEnvConfig defaults for eval mode + group_size: int = 1 + total_steps: int = 1 + steps_per_eval: int = 1 + use_wandb: bool = True + + def get_env_args(self) -> Dict[str, Any]: + """Parse env_args JSON string into dict.""" + if isinstance(self.env_args, dict): + return self.env_args + return json.loads(self.env_args) + + +class VerifiersEvalEnv(BaseEnv): """ - Verifiers Evaluation using EvalBase pattern. + Verifiers Evaluation Environment using BaseEnv pattern. Uses verifiers' native batch evaluation for efficiency, - with EvalBase's standardized logging via evaluate_log(). + with BaseEnv's standardized logging via evaluate_log(). Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ - def __init__( - self, - vf_env_name: str = "primeintellect/gsm8k", - env_args: dict = None, - temperature: float = 0.0, - max_tokens: int = 2048, - max_eval_items: int = -1, - max_concurrent: int = 64, - eval_dir: str = None, - verbose: bool = True, - **kwargs, - ): - self.vf_env_name = vf_env_name - self.env_args = env_args or {} - self.temperature = temperature - self.max_tokens = max_tokens - self.max_eval_items = max_eval_items - self.max_concurrent = max_concurrent + name = "verifiers_eval" + env_config_cls = VfEvalConfig # type: ignore[assignment] - # Load verifiers environment - self.vf_env = vf.load_environment(vf_env_name, **self.env_args) + @classmethod + def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]: + """Return default configurations.""" + env_config = VfEvalConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", + vf_env_name="primeintellect/gsm8k", + wandb_name="verifiers_eval", + ) + server_configs = [ + APIServerConfig( + model_name="Qwen/Qwen2.5-1.5B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + ), + ] + return env_config, server_configs + + async def setup(self): + """Load verifiers environment and dataset.""" + env_args = self.config.get_env_args() + self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args) self.reward_func_names = self.vf_env.rubric._get_reward_func_names() - # Initialize EvalBase (calls setup_data) - super().__init__( - eval_dir=eval_dir, - verbose=verbose, - **kwargs, - ) - - def get_generation_params(self): - """Generation params for logging.""" - return { - "temperature": self.temperature, - "max_tokens": self.max_tokens, - "n": 1, - } - - def setup_data(self) -> list: - """Return evaluation dataset from verifiers environment.""" + # Load evaluation dataset dataset = self.vf_env.get_eval_dataset() - if self.max_eval_items > 0: - n = min(len(dataset), self.max_eval_items) + if self.config.max_eval_items > 0: + n = min(len(dataset), self.config.max_eval_items) dataset = dataset.select(range(n)) - return dataset.to_list() + self.data = dataset.to_list() - async def run_item( - self, server: ServerManager, data_item: dict # noqa: ARG002 - ) -> 
+class VerifiersEvalEnv(BaseEnv):
     """
-    Verifiers Evaluation using EvalBase pattern.
+    Verifiers Evaluation Environment using BaseEnv pattern.
 
     Uses verifiers' native batch evaluation for efficiency,
-    with EvalBase's standardized logging via evaluate_log().
+    with BaseEnv's standardized logging via evaluate_log().
 
     Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
     """
 
-    def __init__(
-        self,
-        vf_env_name: str = "primeintellect/gsm8k",
-        env_args: dict = None,
-        temperature: float = 0.0,
-        max_tokens: int = 2048,
-        max_eval_items: int = -1,
-        max_concurrent: int = 64,
-        eval_dir: str = None,
-        verbose: bool = True,
-        **kwargs,
-    ):
-        self.vf_env_name = vf_env_name
-        self.env_args = env_args or {}
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.max_eval_items = max_eval_items
-        self.max_concurrent = max_concurrent
+    name = "verifiers_eval"
+    env_config_cls = VfEvalConfig  # type: ignore[assignment]
 
-        # Load verifiers environment
-        self.vf_env = vf.load_environment(vf_env_name, **self.env_args)
+    @classmethod
+    def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]:
+        """Return default configurations."""
+        env_config = VfEvalConfig(
+            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
+            vf_env_name="primeintellect/gsm8k",
+            wandb_name="verifiers_eval",
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+            ),
+        ]
+        return env_config, server_configs
+
+    async def setup(self):
+        """Load verifiers environment and dataset."""
+        env_args = self.config.get_env_args()
+        self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args)
         self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
 
-        # Initialize EvalBase (calls setup_data)
-        super().__init__(
-            eval_dir=eval_dir,
-            verbose=verbose,
-            **kwargs,
-        )
-
-    def get_generation_params(self):
-        """Generation params for logging."""
-        return {
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "n": 1,
-        }
-
-    def setup_data(self) -> list:
-        """Return evaluation dataset from verifiers environment."""
+        # Load evaluation dataset
         dataset = self.vf_env.get_eval_dataset()
-        if self.max_eval_items > 0:
-            n = min(len(dataset), self.max_eval_items)
+        if self.config.max_eval_items > 0:
+            n = min(len(dataset), self.config.max_eval_items)
             dataset = dataset.select(range(n))
-        return dataset.to_list()
+        self.data = dataset.to_list()
 
-    async def run_item(
-        self, server: ServerManager, data_item: dict  # noqa: ARG002
-    ) -> Tuple[dict, list]:
-        """Not used - we override __call__ for batch evaluation."""
-        raise NotImplementedError(
-            "VerifiersEval uses batch evaluation via __call__, not per-item run_item"
+    async def get_next_item(self):
+        """Not used in eval mode - stub implementation."""
+        return None
+
+    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
+        """Not used in eval mode - stub implementation."""
+        _ = item  # unused in eval mode
+        return (
+            ScoredDataGroup(
+                tokens=[],
+                masks=[],
+                scores=[],
+                messages=[],
+                inference_logprobs=[],
+                advantages=[],
+                ref_logprobs=[],
+                generation_params=None,
+                group_overrides=None,
+                overrides=[],
+                images=[],
+            ),
+            [],
        )
 
-    async def __call__(self, server_manager: ServerManager):
-        """Run evaluation using verifiers' native batch machinery."""
+    async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
+        """Run evaluation using verifiers with ManagedServer."""
         start_time = time.time()
 
-        # Create OpenAI client from server config
-        server = server_manager.servers[0]
-        client = AsyncOpenAI(
-            api_key=server.config.api_key or "x",
-            base_url=server.config.base_url,
-            timeout=getattr(server.config, "timeout", 600),
-        )
-        model = server.config.model_name
+        # Get server config
+        if hasattr(self.server, "servers") and self.server.servers:
+            server_config = self.server.servers[0].config
+        else:
+            server_config = self.server_configs[0]
+
+        model_name = server_config.model_name
 
         print(f"\n{'=' * 60}")
-        print(f"Verifiers Evaluation: {self.vf_env_name}")
+        print(f"Verifiers Evaluation: {self.config.vf_env_name}")
         print(f"{'=' * 60}")
-        print(f"  Model: {model}")
+        print(f"  Model: {model_name}")
         print(f"  Items: {len(self.data)}")
         print(f"  Reward functions: {self.reward_func_names}")
-        print(f"  Temperature: {self.temperature}")
-        print(f"  Max concurrent: {self.max_concurrent}")
+        print(f"  Temperature: {self.config.eval_temperature}")
+        print(f"  Max concurrent: {self.config.max_concurrent}")
         print(f"{'=' * 60}\n")
 
-        num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
-
-        # Use verifiers' batch evaluation
-        results = await self.vf_env.evaluate(
-            client=client,
-            model=model,
-            sampling_args={
-                "temperature": self.temperature,
-                "max_tokens": self.max_tokens,
-            },
-            num_examples=num_examples,
-            max_concurrent=self.max_concurrent,
-            save_results=False,
+        num_examples = (
+            self.config.max_eval_items if self.config.max_eval_items > 0 else -1
         )
 
+        # Use ManagedServer for automatic token/logprob tracking
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Create adapter that looks like AsyncOpenAI for verifiers
+            adapter = ManagedServerAdapter(
+                managed_server=managed,
+                base_url=server_config.base_url,
+            )
+
+            # Use verifiers' batch evaluation
+            results = await self.vf_env.evaluate(
+                client=adapter,
+                model=model_name,
+                sampling_args={
+                    "temperature": self.config.eval_temperature,
+                    "max_tokens": self.config.eval_max_tokens,
+                },
+                num_examples=num_examples,
+                max_concurrent=self.config.max_concurrent,
+                save_results=False,
+            )
+
         end_time = time.time()
 
         # Extract from verifiers output
@@ -193,6 +250,11 @@ class VerifiersEval(EvalBase):
             "avg_score": avg_score,
         }
 
+        # Add per-function metrics
+        for func_name, data in reward_breakdown.items():
+            metrics[f"{func_name}_avg"] = data["avg"]
+            metrics[f"{func_name}_correct_rate"] = data["correct"] / total
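+        # Illustrative example: a reward function named "correct_answer" (hypothetical
+        # name) would produce the keys "correct_answer_avg" and "correct_answer_correct_rate".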
+
         # Print results summary
         print(f"\n{'=' * 60}")
         print("Verifiers Evaluation Results")
@@ -233,114 +295,25 @@ class VerifiersEval(EvalBase):
             }
         )
 
-        # Use EvalBase's evaluate_log
-        task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
-        evaluate_log(
+        # Use BaseEnv's evaluate_log
+        task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}"
+        await self.evaluate_log(
             metrics=metrics,
-            eval_dir=getattr(self, "eval_dir", None),
             task_name=task_name,
-            model_name=model,
+            model_name=model_name,
             start_time=start_time,
             end_time=end_time,
-            generation_parameters=self.get_generation_params(),
+            generation_parameters={
+                "temperature": self.config.eval_temperature,
+                "max_tokens": self.config.eval_max_tokens,
+                "n": 1,
+            },
             samples=samples,
-            verbose=getattr(self, "verbose", False),
+            verbose=True,
         )
 
         return metrics
 
 
-async def main():
-    """CLI entry point for verifiers evaluation."""
-    import os
-
-    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
-
-    parser = argparse.ArgumentParser(
-        description="Evaluate models using Verifiers environments"
-    )
-    # Server args (same as eval_runner)
-    parser.add_argument(
-        "--server-url",
-        type=str,
-        default="http://localhost:8000/v1",
-        help="URL of the inference server",
-    )
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        required=True,
-        help="Model name to evaluate",
-    )
-    parser.add_argument(
-        "--api-key",
-        type=str,
-        default=os.getenv("OPENAI_API_KEY", "x"),
-        help="API key (defaults to OPENAI_API_KEY env var)",
-    )
-    # Verifiers-specific args
-    parser.add_argument(
-        "--vf-env-name",
-        type=str,
-        default="primeintellect/gsm8k",
-        help="Verifiers environment name (e.g., primeintellect/gsm8k)",
-    )
-    parser.add_argument(
-        "--max-eval-items",
-        type=int,
-        default=-1,
-        help="Maximum items to evaluate (-1 for all)",
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.0,
-        help="Generation temperature",
-    )
-    parser.add_argument(
-        "--max-tokens",
-        type=int,
-        default=2048,
-        help="Maximum tokens per completion",
-    )
-    parser.add_argument(
-        "--max-concurrent",
-        type=int,
-        default=64,
-        help="Maximum concurrent requests",
-    )
-    parser.add_argument(
-        "--eval-dir",
-        type=str,
-        default=None,
-        help="Directory to save evaluation results",
-    )
-    args = parser.parse_args()
-
-    # Create server manager
-    server_manager = ServerManager(
-        configs=[
-            APIServerConfig(
-                api_key=args.api_key,
-                base_url=args.server_url,
-                model_name=args.model_name,
-                health_check=False,
-            ),
-        ]
-    )
-
-    # Create and run evaluation
-    eval_instance = VerifiersEval(
-        vf_env_name=args.vf_env_name,
-        max_eval_items=args.max_eval_items,
-        temperature=args.temperature,
-        max_tokens=args.max_tokens,
-        max_concurrent=args.max_concurrent,
-        eval_dir=args.eval_dir,
-        verbose=True,
-    )
-    return await eval_instance(server_manager)
-
-
 if __name__ == "__main__":
-    asyncio.run(main())
+    VerifiersEvalEnv.cli()
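+    # Example invocation (flags as documented in the module docstring above):
+    #   python verifiers_eval.py evaluate --env.vf_env_name "primeintellect/gsm8k" \
+    #       --openai.base_url "http://localhost:8000/v1"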
diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py
index 0cd8cf45..98413d8a 100644
--- a/environments/verifiers_server.py
+++ b/environments/verifiers_server.py
@@ -46,65 +46,11 @@ from atroposlib.envs.base import (
     BaseEnvConfig,
     ScoredDataGroup,
 )
-from atroposlib.envs.server_handling.managed_server import ManagedServer
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
 
 logger = logging.getLogger(__name__)
 
 
-class ManagedServerAdapter:
-    """
-    Adapter that makes ManagedServer look like AsyncOpenAI for verifiers.
-
-    Implements the subset of AsyncOpenAI interface that verifiers uses:
-    - client.chat.completions.create()
-    - client.completions.create()
-    - client.base_url
-    """
-
-    def __init__(self, managed_server: ManagedServer, base_url: str):
-        self._managed = managed_server
-        self.base_url = base_url
-        self.chat = self._ChatNamespace(self._managed)
-        self.completions = self._CompletionsNamespace(self._managed)
-
-    class _ChatNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
-
-    class _ChatCompletionsNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-
-        async def create(self, **kwargs):
-            logger.info(
-                "ManagedServerAdapter.chat.completions.create called with model=%s",
-                kwargs.get("model"),
-            )
-            result = await self._managed.chat_completion(**kwargs)
-            logger.info("ManagedServerAdapter.chat.completions.create completed")
-            return result
-
-    class _CompletionsNamespace:
-        def __init__(self, managed: ManagedServer):
-            self._managed = managed
-
-        async def create(self, **kwargs):
-            return await self._managed.completion(**kwargs)
-
-    async def post(self, path: str, body: dict, cast_to: type):
-        raise NotImplementedError(
-            f"ManagedServerAdapter does not support post() for path '{path}'. "
-            "This is used for vLLM interleaved rollouts. Use standard chat completions."
-        )
-
-    def copy(self, **kwargs):
-        raise NotImplementedError(
-            "ManagedServerAdapter does not support copy(). "
-            "This is used for vLLM tokenization endpoints."
-        )
-
-
 class VfEnvConfig(BaseEnvConfig):
     vf_env_name: str = ""
     env_args: str = "{}"