diff --git a/README.md b/README.md
index 4cf96783..7bddf763 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,7 @@ If you're looking to get into developing the repo or using the environments:
 pip install -e . # for using
 pip install -e .[dev] # for development
 pip install -e .[examples] # for running examples
+pip install -e .[verifiers] # for verifiers integration
 pip install -e .[all] # for everything
 ```
diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index cbc97268..7b26b5c1 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -509,3 +509,63 @@ class ManagedServer:
             self.sequences.clear()
         else:
             self.current_nodes.clear()
+
+
+class ManagedServerAdapter:
+    """
+    Adapter that makes ManagedServer look like AsyncOpenAI for external libraries.
+
+    Implements the subset of the AsyncOpenAI interface commonly used:
+    - client.chat.completions.create()
+    - client.completions.create()
+    - client.base_url
+
+    This allows libraries like verifiers to use ManagedServer transparently
+    while still getting automatic token and logprob tracking.
+    """
+
+    def __init__(self, managed_server: ManagedServer, base_url: str):
+        """
+        Initialize the adapter.
+
+        Args:
+            managed_server: The ManagedServer instance to wrap
+            base_url: The base URL to expose (for compatibility checks)
+        """
+        self._managed = managed_server
+        self.base_url = base_url
+        self.chat = self._ChatNamespace(self._managed)
+        self.completions = self._CompletionsNamespace(self._managed)
+
+    class _ChatNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+            self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed)
+
+    class _ChatCompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.chat_completion(**kwargs)
+
+    class _CompletionsNamespace:
+        def __init__(self, managed: ManagedServer):
+            self._managed = managed
+
+        async def create(self, **kwargs):
+            return await self._managed.completion(**kwargs)
+
+    async def post(self, path: str, body: dict, cast_to: type):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            f"ManagedServerAdapter does not support post() for path '{path}'. "
+            "This is used for vLLM interleaved rollouts. Use standard chat completions."
+        )
+
+    def copy(self, **kwargs):
+        """Not supported - raises NotImplementedError."""
+        raise NotImplementedError(
+            "ManagedServerAdapter does not support copy(). "
+            "This is used for vLLM tokenization endpoints."
+        )
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index e9a75766..41fec651 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -177,6 +177,8 @@ class VLLMServer(APIServer):
             prompt_tokens = prompt_tokens[1:]
         if "max_new_tokens" in kwargs:
             kwargs["max_tokens"] = kwargs.pop("max_new_tokens")
+        if "max_completion_tokens" in kwargs:
+            kwargs["max_tokens"] = kwargs.pop("max_completion_tokens")
         if "model" in kwargs:
             kwargs.pop("model")
         # Prepare request for VLLM native API
diff --git a/environments/README.md b/environments/README.md
index 79201afb..ee700dde 100644
--- a/environments/README.md
+++ b/environments/README.md
@@ -11,7 +11,159 @@ This directory contains various environments for training and evaluating languag
 
 ---
 
+### Prime Intellect Verifiers Integration
+A flexible environment that integrates with the [Verifiers](https://docs.primeintellect.ai/) ecosystem, allowing you to use any registered Prime environment for RL training, SFT data generation, or evaluation.
+
+**Files:**
+- `environments/verifiers_server.py` - Training and SFT data generation
+- `environments/eval_environments/verifiers_eval.py` - Standalone evaluation
+
+**Dependencies:**
+
+- `verifiers` Python package (install via `pip install verifiers` or include in your environment)
+- Prime CLI for environment management (`uv tool install prime`)
+- Prime CLI login required (`prime login`)
+- Environment installation (`prime env install owner/env_name`)
+
+**Supported Modes:**
+
+| Mode | File | Description |
+|------|------|-------------|
+| `serve` | `verifiers_server.py` | RL training with local inference server (requires ManagedServer for logprobs) |
+| `process` | `verifiers_server.py` | SFT data generation with ANY API (OpenAI, Claude, local, etc.) |
+| `evaluate` | `verifiers_server.py` | Quick evaluation using ManagedServer |
+| `evaluate` | `verifiers_eval.py` | Standalone evaluation with detailed metrics and retry logic |
+
+**Input Format:**
+
+- Loaded dynamically from the specified Prime environment via `vf.load_environment()`
+- Each item contains:
+  - `question`: The problem/prompt
+  - `answer`: The expected answer for verification
+
+**System Prompt:**
+
+- Dynamically loaded from the Prime environment's `system_prompt` configuration
+
+**Reward Function:**
+
+- Uses the environment's **rubric** system with:
+  - `parser`: Extracts answers from completions (e.g., `parser.parse_answer(completion)`)
+  - `funcs`: List of reward functions that receive `(parser, completion, answer)`
+  - `weights`: Weights for combining reward functions (normalized to sum to 1.0)
+- Final score is the weighted sum of all reward function outputs; a worked sketch follows below
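+
+For intuition, a minimal sketch of how that weighted combination works. This is illustrative only (the actual combination happens inside verifiers' rubric), and the reward functions and weights below are made up:
+
+```python
+# Illustrative only: how a rubric-style weighted score is combined.
+# `exact_match` and `has_final_answer` are hypothetical reward functions;
+# `parser` is the environment's parser object.
+def exact_match(parser, completion, answer) -> float:
+    return 1.0 if parser.parse_answer(completion) == answer else 0.0
+
+def has_final_answer(parser, completion, answer) -> float:
+    return 1.0 if parser.parse_answer(completion) is not None else 0.0
+
+def combined_score(parser, completion, answer) -> float:
+    funcs = [exact_match, has_final_answer]  # rubric funcs
+    weights = [0.9, 0.1]                     # rubric weights (normalized to sum to 1.0)
+    return sum(w * f(parser, completion, answer) for w, f in zip(weights, funcs))
+```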
+
+**W&B Metrics Logged (Training - `verifiers_server.py`):**
+
+| Metric | Description |
+|--------|-------------|
+| `train/percent_correct` | Average score from verifiers reward functions (0-1) |
+| `train/rollouts` | Table of tokenized completions with scores |
+| `train/completion_lengths_*` | Response length statistics (std, min, max, p95) |
+| `server/server_0_request_time_*` | API latency metrics (avg, std, 99p) |
+| `eval/avg_total_score` | Average score on evaluation dataset |
+
+**Output (Evaluation - `verifiers_eval.py`):**
+
+Uses `evaluate_log()` from `EvalBase` to output:
+- Console: Metrics table with accuracy, avg_score, time, and a per-reward-function breakdown
+- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified); an example record is sketched below
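+
+Each line of `samples.jsonl` is one evaluated item. Roughly, a record carries the fields that `verifiers_eval.py` builds per sample; the values below are illustrative, and the per-reward metric names depend on the chosen environment:
+
+```python
+# Shape of one samples.jsonl record, shown as the Python dict it is serialized from.
+# Keys mirror what verifiers_eval.py writes; values are made up.
+sample = {
+    "messages": [
+        {"role": "system", "content": "Solve the problem step by step."},
+        {"role": "user", "content": "What is 17 + 25?"},
+        {"role": "assistant", "content": "17 + 25 = 42. Answer: 42"},
+    ],
+    "gold_answer": "42",
+    "score": 1.0,
+    "correct": True,
+    "metrics": {"exact_match": 1.0},  # per-reward-function values, names vary by environment
+}
+```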
+
+**Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `vf_env_name` | str | `""` | Prime environment identifier (e.g., `"will/wordle"`, `"primeintellect/gsm8k"`) |
+| `env_args` | Dict or JSON str | `{}` | Additional arguments passed to `vf.load_environment()`; see the environment's own documentation for the args it accepts (example below) |
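+
+`env_args` is simply parsed and unpacked into `vf.load_environment()`. A minimal sketch of that plumbing; the `num_eval_examples` key is a hypothetical environment arg, not something every environment accepts:
+
+```python
+import json
+
+import verifiers as vf
+
+# As parsed by VfEnvConfig.get_env_args(); the key here is hypothetical.
+env_args = json.loads('{"num_eval_examples": 50}')
+# Same call verifiers_server.py makes at startup.
+vf_env = vf.load_environment("primeintellect/gsm8k", **env_args)
+```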
+
+**CLI Options (`verifiers_eval.py`):**
+
+Uses a simple argparse CLI with direct arguments:
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server |
+| `--model-name` | str | (required) | Model name to evaluate |
+| `--api-key` | str | `$OPENAI_API_KEY` | API key (uses env var if not specified) |
+| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier |
+| `--temperature` | float | `0.0` | Temperature for generation |
+| `--max-tokens` | int | `2048` | Maximum tokens per completion |
+| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) |
+| `--max-concurrent` | int | `64` | Maximum concurrent requests |
+| `--eval-dir` | str | `None` | Directory to save evaluation results |
+
+**Usage Examples:**
+
+```bash
+# RL Training (requires local vLLM/SGLang server)
+python verifiers_server.py serve \
+  --env.vf_env_name "will/wordle" \
+  --openai.base_url http://localhost:9001/v1 \
+  --slurm false
+
+# SFT Data Generation with OpenAI GPT-4o
+python verifiers_server.py process \
+  --env.vf_env_name "will/wordle" \
+  --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
+  --env.total_steps 100 \
+  --env.group_size 4 \
+  --openai.model_name gpt-4o \
+  --openai.base_url https://api.openai.com/v1
+
+# SFT Data Generation with local server
+python verifiers_server.py process \
+  --env.vf_env_name "will/wordle" \
+  --env.data_path_to_save_groups local_sft_data.jsonl \
+  --openai.base_url http://localhost:9001/v1
+
+# Quick Evaluation via verifiers_server.py
+python verifiers_server.py evaluate \
+  --env.vf_env_name "will/wordle" \
+  --openai.base_url http://localhost:9001/v1
+
+# Standalone Evaluation with OpenAI (verifiers_eval.py)
+python eval_environments/verifiers_eval.py \
+  --server-url https://api.openai.com/v1 \
+  --model-name gpt-4o \
+  --vf-env-name primeintellect/gsm8k
+
+# Quick test run with limited items
+python eval_environments/verifiers_eval.py \
+  --server-url https://api.openai.com/v1 \
+  --model-name gpt-4o-mini \
+  --vf-env-name primeintellect/alphabet-sort \
+  --max-eval-items 10
+
+# Evaluation with local server and results saved
+python eval_environments/verifiers_eval.py \
+  --server-url http://localhost:9001/v1 \
+  --model-name Qwen/Qwen2.5-7B-Instruct \
+  --vf-env-name primeintellect/gsm8k \
+  --eval-dir ./eval_results
+```
+
+**Key Implementation Details:**
+
+- **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
+- **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
+- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` with a simple argparse CLI. Uses verifiers' native batch evaluation with `ManagedServerAdapter` for token/logprob tracking and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. See the adapter sketch below.
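+
+How the adapter is wired in practice, condensed from `collect_trajectories` in `verifiers_server.py` (error handling, metrics, and some arguments omitted):
+
+```python
+# Condensed sketch of the serve-mode rollout path in verifiers_server.py.
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+    # ManagedServerAdapter exposes the AsyncOpenAI surface verifiers expects,
+    # while ManagedServer records tokens and logprobs for every request.
+    adapter = ManagedServerAdapter(
+        managed_server=managed, base_url=server_config.base_url
+    )
+    results = await self.vf_env.generate(
+        inputs=inputs,  # group_size copies of the sampled item
+        client=adapter,
+        model=server_config.model_name,
+        sampling_args={
+            "temperature": 1.0,
+            "max_completion_tokens": self.config.max_token_length,
+        },
+    )
+    # Token/mask/logprob records, one per rollout, used to build ScoredDataGroup.
+    nodes = managed.get_state()["nodes"]
+```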
+
+**Prime Environment Installation:**
+```bash
+# Install Prime CLI
+uv tool install prime
+
+# Login to Prime
+prime login
+
+# Install an environment (e.g., Wordle, GSM8K)
+prime env install will/wordle
+prime env install primeintellect/gsm8k
+
+# List available environments
+prime env list
+```
 
 ### Letter Counting Environment (`letter_counting_environment.py`)
diff --git a/environments/configs/verifiers.yaml b/environments/configs/verifiers.yaml
new file mode 100644
index 00000000..91ef7ec2
--- /dev/null
+++ b/environments/configs/verifiers.yaml
@@ -0,0 +1,31 @@
+# Verifiers environment configuration
+# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml
+#
+# For SFT data generation with external API:
+#   python environments/verifiers_server.py process \
+#     --env.vf_env_name primeintellect/gsm8k \
+#     --env.data_path_to_save_groups output.jsonl \
+#     --openai.base_url https://api.openai.com/v1 \
+#     --openai.api_key $OPENAI_API_KEY \
+#     --openai.model_name gpt-4o
+
+env:
+  vf_env_name: "primeintellect/gsm8k"  # Prime Env Hub environment
+  env_args: {}
+  group_size: 8
+  max_token_length: 2048
+  tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct"
+  rollout_server_url: "http://localhost:8000"
+  use_wandb: true
+  wandb_name: "verifiers"
+  total_steps: 1000
+  batch_size: 4
+  steps_per_eval: 100
+
+openai:
+  - model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+    base_url: "http://localhost:9001/v1"
+    api_key: "x"
+
+slurm: false
+testing: false
diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py
new file mode 100644
index 00000000..559b0fd0
--- /dev/null
+++ b/environments/eval_environments/verifiers_eval.py
@@ -0,0 +1,369 @@
+"""
+Verifiers Evaluation Environment for Atropos
+
+This environment evaluates models using Prime Intellect's Verifiers library.
+It supports any environment registered with the Verifiers ecosystem.
+
+Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible
+client and verifiers handles generation, parsing, and scoring.
+
+To install a Verifiers/Prime environment:
+1. uv tool install prime
+2. prime login
+3. prime env install will/wordle (or any owner/environment)
+Docs: https://docs.primeintellect.ai/tutorials-environments/install
+
+Usage:
+    # Evaluate with OpenAI
+    python verifiers_eval.py \
+        --server-url https://api.openai.com/v1 \
+        --model-name gpt-4o \
+        --vf-env-name primeintellect/gsm8k \
+        --max-eval-items 50
+
+    # Evaluate with local server
+    python verifiers_eval.py \
+        --server-url http://localhost:8000/v1 \
+        --model-name Qwen/Qwen2.5-7B-Instruct \
+        --vf-env-name primeintellect/gsm8k
+
+Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
+"""
+
+import argparse
+import json
+import time
+from typing import Any, Dict
+
+import verifiers as vf
+
+from atroposlib.envs.eval import EvalBase, evaluate_log
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
+from atroposlib.envs.server_handling.server_baseline import APIServerConfig
+from atroposlib.envs.server_handling.server_manager import ServerManager
+
+# Patch math_verify timeout to work in async context
+# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
+
+
+def _no_signal_timeout(
+    _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None
+):
+    """Replacement timeout decorator that doesn't use signals.
+
+    Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)).
+    """
+    # Silence unused parameter warnings - these match the original API signature
+    del _timeout_seconds, timeout_seconds
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            # Just call the function without timeout - safe in async context
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+try:
+    import math_verify.grader
+    import math_verify.parser
+    import math_verify.utils
+
+    # Patch all modules that use the timeout decorator
+    math_verify.utils.timeout = _no_signal_timeout
+    math_verify.parser.timeout = _no_signal_timeout
+    math_verify.grader.timeout = _no_signal_timeout
+except ImportError:
+    pass  # math_verify not installed
+
+
+class VerifiersEval(EvalBase):
+    """
+    Verifiers Evaluation using EvalBase pattern.
+
+    Uses verifiers' native batch evaluation for efficiency,
+    with ManagedServerAdapter for token/logprob tracking.
+
+    Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.)
+    """
+
+    def __init__(
+        self,
+        vf_env_name: str = "primeintellect/gsm8k",
+        env_args: str = "{}",
+        temperature: float = 0.0,
+        max_tokens: int = 2048,
+        max_eval_items: int = -1,
+        max_concurrent: int = 64,
+        **kwargs,
+    ):
+        self.vf_env_name = vf_env_name
+        self.env_args_str = env_args
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_eval_items = max_eval_items
+        self.max_concurrent = max_concurrent
+        super().__init__(**kwargs)
+
+    def get_env_args(self) -> Dict[str, Any]:
+        """Parse env_args JSON string into dict."""
+        if isinstance(self.env_args_str, dict):
+            return self.env_args_str
+        return json.loads(self.env_args_str)
+
+    def setup_data(self) -> list:
+        """Load verifiers environment and dataset."""
+        env_args = self.get_env_args()
+        self.vf_env = vf.load_environment(self.vf_env_name, **env_args)
+        self.reward_func_names = self.vf_env.rubric._get_reward_func_names()
+
+        # Load evaluation dataset
+        dataset = self.vf_env.get_eval_dataset()
+        if self.max_eval_items > 0:
+            n = min(len(dataset), self.max_eval_items)
+            dataset = dataset.select(range(n))
+        return dataset.to_list()
+
+    async def run_item(self, server: ServerManager, data_item: dict):
+        """Not used - verifiers uses batch evaluation in __call__."""
+        # This won't be called since we override __call__
+        raise NotImplementedError("VerifiersEval uses batch evaluation in __call__")
+
+    async def __call__(self, server_manager: ServerManager):
+        """Run evaluation using verifiers with ManagedServerAdapter."""
+        start_time = time.time()
+
+        # Get server config
+        server = server_manager.servers[0]
+        model_name = server.config.model_name
+
+        num_examples = self.max_eval_items if self.max_eval_items > 0 else -1
+
+        # Use ManagedServer for automatic token/logprob tracking
+        async with server_manager.managed_server(tokenizer=None) as managed:
+            # Create adapter that looks like AsyncOpenAI for verifiers
+            adapter = ManagedServerAdapter(
+                managed_server=managed,
+                base_url=server.config.base_url,
+            )
+
+            # Use verifiers' batch evaluation
+            results = await self.vf_env.evaluate(
+                client=adapter,
+                model=model_name,
+                sampling_args={
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                },
+                num_examples=num_examples,
+                max_concurrent=self.max_concurrent,
+                save_results=False,
+            )
+
+        end_time = time.time()
+
+        # Extract from verifiers output
+        rewards = results["reward"]
+        per_func_metrics = results["metrics"]
+        prompts = results["prompt"]
+        completions = results["completion"]
+        answers = results["answer"]
+
+        total = len(rewards)
+        correct = sum(1 for r in rewards if r > 0)
+        avg_score = sum(rewards) / total if total > 0 else 0.0
+        accuracy = correct / total if total > 0 else 0.0
+
+        # Per-reward function breakdown
+        reward_breakdown = {}
+        for func_name, values in per_func_metrics.items():
+            if values:
+                reward_breakdown[func_name] = {
+                    "avg": sum(values) / len(values),
+                    "correct": sum(1 for v in values if v > 0),
+                }
+
+        metrics = {
+            "accuracy": accuracy,
+            "avg_score": avg_score,
+        }
+
+        # Add per-function metrics
+        for func_name, data in reward_breakdown.items():
+            metrics[f"{func_name}_avg"] = data["avg"]
+            metrics[f"{func_name}_correct_rate"] = data["correct"] / total
+
+        # Print results summary
+        print(f"\n{'=' * 60}")
+        print("Verifiers Evaluation Results")
+        print(f"{'=' * 60}")
+        print(f"  Average Score: {avg_score:.4f}")
+        print(f"  Accuracy: {accuracy:.2%} ({correct}/{total})")
+        print(f"  Time: {end_time - start_time:.1f}s")
+        print("\n  Per-Reward Function:")
+        for name, data in reward_breakdown.items():
+            print(
+                f"    {name}: avg={data['avg']:.4f}, correct={data['correct']}/{total}"
+            )
+        print(f"{'=' * 60}\n")
+
+        # Build samples for logging
+        system_prompt = self.vf_env.system_prompt or ""
+        samples = []
+        for i in range(min(total, 100)):  # Limit samples for logging
+            prompt_msgs = prompts[i] if isinstance(prompts[i], list) else []
+            completion_msgs = completions[i] if completions[i] else []
+
+            messages = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.extend(prompt_msgs)
+            if isinstance(completion_msgs, list):
+                messages.extend(completion_msgs)
+
+            samples.append(
+                {
+                    "messages": messages,
+                    "gold_answer": answers[i] if i < len(answers) else "",
+                    "score": rewards[i],
+                    "correct": rewards[i] > 0,
+                    "metrics": {
+                        k: v[i] for k, v in per_func_metrics.items() if i < len(v)
+                    },
+                }
+            )
+
+        # Log results
+        task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}"
+        evaluate_log(
+            metrics=metrics,
+            eval_dir=getattr(self, "eval_dir", None),
+            task_name=task_name,
+            model_name=model_name,
+            start_time=start_time,
+            end_time=end_time,
+            generation_parameters={
+                "temperature": self.temperature,
+                "max_tokens": self.max_tokens,
+                "n": 1,
+            },
+            samples=samples,
+            verbose=getattr(self, "verbose", True),
+        )
+
+        return metrics
+
+
+async def main():
+    """Run verifiers evaluation with argparse CLI."""
+    import os
+
+    parser = argparse.ArgumentParser(
+        description="Evaluate models using Prime Intellect's Verifiers library"
+    )
+    parser.add_argument(
+        "--server-url",
+        type=str,
+        default="http://localhost:8000/v1",
+        help="URL of the inference server (default: http://localhost:8000/v1)",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        required=True,
+        help="Model name to evaluate",
+    )
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help="API key (default: uses OPENAI_API_KEY env var)",
+    )
+    parser.add_argument(
+        "--vf-env-name",
+        type=str,
+        default="primeintellect/gsm8k",
+        help="Verifiers environment name (default: primeintellect/gsm8k)",
+    )
+    parser.add_argument(
+        "--env-args",
+        type=str,
+        default="{}",
+        help="JSON string of environment-specific args (default: {})",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.0,
+        help="Generation temperature (default: 0.0)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=2048,
+        help="Maximum tokens per completion (default: 2048)",
+    )
+    parser.add_argument(
+        "--max-eval-items",
+        type=int,
+        default=-1,
+        help="Maximum items to evaluate, -1 for all (default: -1)",
+    )
+    parser.add_argument(
+        "--max-concurrent",
+        type=int,
+        default=64,
+        help="Maximum concurrent requests (default: 64)",
+    )
+    parser.add_argument(
+        "--eval-dir",
+        type=str,
+        default=None,
+        help="Directory to save evaluation results (default: None)",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        default=True,
+        help="Print verbose output (default: True)",
+    )
+
+    args = parser.parse_args()
+
+    # Get API key from args or environment
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "dummy")
+
+    # Create evaluation instance
+    eval_env = VerifiersEval(
+        vf_env_name=args.vf_env_name,
+        env_args=args.env_args,
+        temperature=args.temperature,
+        max_tokens=args.max_tokens,
+        max_eval_items=args.max_eval_items,
+        max_concurrent=args.max_concurrent,
+        eval_dir=args.eval_dir,
+        verbose=args.verbose,
+    )
+
+    # Create server manager
+    server_manager = ServerManager(
+        configs=[
+            APIServerConfig(
+                api_key=api_key,
+                base_url=args.server_url,
+                model_name=args.model_name,
+                health_check=False,
+            ),
+        ]
+    )
+
+    # Run evaluation
+    return await eval_env(server_manager)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py
new file mode 100644
index 00000000..98413d8a
--- /dev/null
+++ b/environments/verifiers_server.py
@@ -0,0 +1,330 @@
+"""
+Verifiers Training Environment for Atropos
+
+Unified environment that works for both RL training (serve) and SFT data generation (process).
+Uses vf_env.generate() with ManagedServer (via adapter) for automatic token and logprob tracking.
+
+Usage:
+    # RL Training (GRPO - no inference logprobs needed)
+    python verifiers_server.py serve \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
+        --openai.base_url http://localhost:9001/v1 \
+        --slurm false
+
+    # SFT Data Generation with OpenAI GPT-4o
+    python verifiers_server.py process \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
+        --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
+        --env.total_steps 100 \
+        --env.group_size 4 \
+        --openai.model_name gpt-4o \
+        --openai.base_url https://api.openai.com/v1
+
+    # SFT Data Generation with local server
+    python verifiers_server.py process \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
+        --env.data_path_to_save_groups local_sft_data.jsonl \
+        --openai.base_url http://localhost:9001/v1
+
+To install a Verifiers/Prime environment:
+1. uv tool install prime
+2. prime login
+3. prime env install primeintellect/alphabet-sort (or any owner/environment)
+Docs: https://docs.primeintellect.ai/tutorials-environments/install
+"""
+
+import json
+import logging
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple
+
+import verifiers as vf
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter
+
+logger = logging.getLogger(__name__)
+
+
+class VfEnvConfig(BaseEnvConfig):
+    vf_env_name: str = ""
+    env_args: str = "{}"
+
+    def get_env_args(self) -> Dict[str, Any]:
+        """Parse env_args JSON string into dict."""
+        if isinstance(self.env_args, dict):
+            return self.env_args
+        return json.loads(self.env_args)
+
+
+class VerifiersEnv(BaseEnv):
+    name = "verifiers"
+    env_config_cls = VfEnvConfig  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        config: VfEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm=False,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+
+        # Metrics buffers for wandb logging
+        self.reward_buffer: List[float] = []
+        self.metrics_buffer: Dict[str, List[float]] = defaultdict(list)
+        self.num_turns_buffer: List[int] = []
+        self.groups_with_identical_scores: int = 0
+        self.groups_total: int = 0
+
+        logger.info("Loading verifiers environment: %s", config.vf_env_name)
+        env_args = config.get_env_args()
+        if env_args:
+            logger.info("Environment args: %s", env_args)
+        self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
+        self.rubric = self.vf_env.rubric
+        self.system_prompt = self.vf_env.system_prompt
+
+        # Get reward function names for metrics reporting
+        self.reward_func_names = self.rubric._get_reward_func_names()
+        logger.info("Reward functions: %s", self.reward_func_names)
+
+        # Log multi-turn config if available
+        if hasattr(self.vf_env, "max_turns"):
+            logger.info("Max turns: %d", self.vf_env.max_turns)
+
+    @classmethod
+    def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
+        env_config = VfEnvConfig(
+            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
+            group_size=8,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1000,
+            batch_size=4,
+            steps_per_eval=100,
+            max_token_length=2048,
+            wandb_name="verifiers",
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=4,
+                server_type="sglang",
+            ),
+        ]
+        return env_config, server_configs
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """Enhanced wandb logging with verifiers-specific metrics."""
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Log mean reward across all rollouts
+        if self.reward_buffer:
+            wandb_metrics["metrics/mean_reward"] = sum(self.reward_buffer) / len(
+                self.reward_buffer
+            )
+            wandb_metrics["metrics/reward_std"] = (
+                (
+                    sum(
+                        (r - wandb_metrics["metrics/mean_reward"]) ** 2
+                        for r in self.reward_buffer
+                    )
+                    / len(self.reward_buffer)
+                )
+                ** 0.5
+                if len(self.reward_buffer) > 1
+                else 0.0
+            )
+            self.reward_buffer = []
+
+        # Log per-reward-function metrics (e.g., strict_accuracy, format_score)
+        if self.metrics_buffer:
+            for metric_name, values in self.metrics_buffer.items():
+                if values:
+                    avg_metric = sum(values) / len(values)
+                    wandb_metrics[f"metrics/{metric_name}"] = avg_metric
+            self.metrics_buffer = defaultdict(list)
+
+        # Log multi-turn statistics
+        if self.num_turns_buffer:
+            wandb_metrics["metrics/avg_num_turns"] = sum(self.num_turns_buffer) / len(
+                self.num_turns_buffer
+            )
+            wandb_metrics["metrics/max_num_turns"] = max(self.num_turns_buffer)
+            self.num_turns_buffer = []
+
+        # Log group filtering statistics (helpful for debugging)
+        if self.groups_total > 0:
+            wandb_metrics["metrics/groups_with_identical_scores"] = (
+                self.groups_with_identical_scores
+            )
+            wandb_metrics["metrics/groups_total"] = self.groups_total
+            wandb_metrics["metrics/identical_score_rate"] = (
+                self.groups_with_identical_scores / self.groups_total
+            )
+            # Reset counters
+            self.groups_with_identical_scores = 0
+            self.groups_total = 0
+
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        # Dataset already has: prompt, answer, info, example_id, task
+        train_data = self.vf_env.get_dataset()
+        self.train = train_data.to_list()
+        self.iter = 0
+
+    def save_checkpoint(self, step, data=None):
+        if data is None:
+            data = {}
+        data["iter"] = self.iter
+        super().save_checkpoint(step, data)
+
+    async def get_next_item(self):
+        next_item = self.train[self.iter % len(self.train)]
+        self.iter += 1
+        return next_item
+
+    async def evaluate(self) -> Dict[str, float]:
+        """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
+        return {}
+
+    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
+        """Unified trajectory collection using vf_env.generate() with ManagedServer.
+
+        Works for both RL training (serve) and SFT data generation (process).
+        Uses ManagedServer adapter for automatic token and logprob tracking.
+        """
+        # Get server config (handle both real servers and test harness)
+        if hasattr(self.server, "servers") and self.server.servers:
+            server_config = self.server.servers[0].config
+        else:
+            # Fallback for testing
+            server_config = APIServerConfig(
+                model_name=self.config.tokenizer_name,
+                base_url="http://localhost:8000/v1",
+            )
+
+        # Build inputs for group_size rollouts
+        inputs = [
+            {
+                "prompt": item["prompt"],
+                "answer": item.get("answer", ""),
+                "example_id": item["example_id"],
+                "task": item.get("task", self.config.vf_env_name),
+                "info": item.get("info", {}),
+            }
+            for _ in range(self.config.group_size)
+        ]
+
+        # Use ManagedServer for automatic token/logprob tracking
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Create adapter that looks like AsyncOpenAI for verifiers
+            adapter = ManagedServerAdapter(
+                managed_server=managed,
+                base_url=server_config.base_url,
+            )
+
+            # Use vf_env.generate() - handles batching and scoring internally
+            results = await self.vf_env.generate(
+                inputs=inputs,
+                client=adapter,
+                model=server_config.model_name,
+                sampling_args={
+                    "temperature": 1.0,
+                    "max_completion_tokens": self.config.max_token_length,
+                },
+                max_concurrent=self.config.group_size,
+                max_concurrent_scoring=self.config.group_size,
+                save_results=False,
+                independent_scoring=True,
+            )
+
+            # Get tracked state from ManagedServer
+            managed_state = managed.get_state()
+            nodes = managed_state["nodes"]
+
+        scored_data: ScoredDataGroup = {
+            "tokens": [],
+            "masks": [],
+            "scores": [],
+            "messages": [],
+            "inference_logprobs": [],
+        }
+
+        # Zip verifiers states with ManagedServer nodes for logprob tracking
+        for i, vf_state in enumerate(results["state"]):
+            # Extract messages from state
+            messages = list(vf_state.get("prompt", [])) + list(
+                vf_state.get("completion", [])
+            )
+            messages = [
+                {**msg, "content": msg.get("content") or ""} for msg in messages
+            ]
+
+            # Get trajectory for metrics
+            trajectory = vf_state.get("trajectory", [])
+
+            # Get tokens, masks, and logprobs from ManagedServer
+            # IMPORTANT: We use ManagedServer's tokens (not re-tokenize) to ensure
+            # alignment with logprobs. ManagedServer tracks tokens and logprobs together.
+            if i >= len(nodes):
+                raise RuntimeError(
+                    f"Node count mismatch: expected at least {i + 1} nodes, got {len(nodes)}. "
+                    "ManagedServer should track all rollouts."
+                )
+
+            node = nodes[i]
+            scored_data["tokens"].append(node.tokens)
+            scored_data["masks"].append(node.masked_tokens)
+            scored_data["inference_logprobs"].append(node.logprobs)
+            scored_data["messages"].append(messages)
+
+            reward = vf_state.get("reward", 0.0)
+            scored_data["scores"].append(reward)
+
+            # Metrics logging
+            self.reward_buffer.append(reward)
+            num_turns = len(trajectory)
+            self.num_turns_buffer.append(num_turns)
+            logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
+
+            # Per-function metrics from verifiers state
+            state_metrics = vf_state.get("metrics", {})
+            for metric_name, metric_value in state_metrics.items():
+                if isinstance(metric_value, (int, float)):
+                    self.metrics_buffer[metric_name].append(float(metric_value))
+
+        # Log group summary
+        turns = [len(s.get("trajectory", [])) for s in results["state"]]
+        logger.info(
+            "Group: %d rollouts, turns=%s, rewards=%s, nodes=%d",
+            len(results["state"]),
+            turns,
+            [f"{s:.3f}" for s in scored_data["scores"]],
+            len(nodes),
+        )
+
+        # Track identical scores for debugging
+        self.groups_total += 1
+        if len(set(scored_data["scores"])) == 1:
+            self.groups_with_identical_scores += 1
+            logger.debug(
+                "Group has identical scores (%.3f) - will be filtered by base env",
+                scored_data["scores"][0],
+            )
+
+        return scored_data, []
+
+
+if __name__ == "__main__":
+    VerifiersEnv.cli()
diff --git a/pyproject.toml b/pyproject.toml
index dd3841ee..e61f25a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@ dependencies = [
     "numpy",
     "wandb",
     "gymnasium",
-    "math-verify==0.7.0",
+    "math-verify>=0.8.0",
     "jinja2",
     "nltk",
     "rich",
@@ -58,6 +58,9 @@ examples = [
     "atroposlib[rewardfns]",
     "langdetect"
 ]
+verifiers = [
+    "verifiers==0.1.9.post2"
+]
 
 [build-system]
 requires = ["hatchling"]