From ed826de7242f4a76e01c52910049af4bc38dbcb6 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 14:21:03 +0530 Subject: [PATCH 01/22] wip: verifiers integration --- environments/verifiers_server.py | 232 +++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 233 insertions(+) create mode 100644 environments/verifiers_server.py diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py new file mode 100644 index 00000000..5091dd88 --- /dev/null +++ b/environments/verifiers_server.py @@ -0,0 +1,232 @@ +# To install a Verifiers/Prime environment: +# 1. uv tool install prime +# 2. prime login +# 3. prime env install will/wordle (or any owner/environment) +# Docs: https://docs.primeintellect.ai/tutorials-environments/install + +import os +import time +from typing import List, Tuple + +import verifiers as vf +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) + + +class VfEnvConfig(BaseEnvConfig): + vf_env_name: str = "" + env_args: dict = {} + + +class VerifiersEnv(BaseEnv): + name = "verifiers" + env_config_cls = VfEnvConfig # type: ignore[assignment] + + def __init__( + self, + config: VfEnvConfig, + server_configs: List[APIServerConfig], + slurm=False, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.eval_metrics = list() + self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) + self.rubric = self.vf_env.rubric + + self.parser = self.rubric.parser + self.reward_funcs = self.rubric.funcs + self.reward_weights = self.rubric.weights + self.reward_scales = [ + weight / sum(self.reward_weights) for weight in self.reward_weights + ] + self.system_prompt = self.vf_env.system_prompt + + @classmethod + def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: + env_config = VfEnvConfig( + group_size=8, + use_wandb=False, + rollout_server_url="http://localhost:8000", + total_steps=10, + batch_size=4, + steps_per_eval=1, + max_token_length=2048, + wandb_name="verifiers", + ) + server_configs = [ + APIServerConfig( + model_name="gpt-4.1-nano", + base_url=None, + api_key=os.getenv("OPENAI_API_KEY"), + num_requests_for_eval=4, + ), + ] + return env_config, server_configs + + async def setup(self): + self.train = self.vf_env.get_dataset() + test_data = self.vf_env.get_eval_dataset() + self.test = test_data.select_columns(["question", "answer"]).to_list() + self.iter = 0 + + async def rollout_and_score_eval( + self, question: str, answer: str, **kwargs + ) -> dict: + system_prompt = kwargs.get("system_prompt") + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + completion = await self.server.chat_completion( + messages=messages, + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + ) + + response_content = completion.choices[0].message.content or "" + messages.append({"role": "assistant", "content": response_content}) + + answer_parsed = self.parser.parse_answer(completion=response_content) + + rewards = [] + for func in self.reward_funcs: + reward = func( + parser=self.parser, + completion=messages, + answer=answer, + ) + rewards.append(reward) + weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] + + score = sum(weighted_rewards) + + sample = { + "messages": messages, + "question": question, + "gold_answer": answer, + "model_parsed": str(answer_parsed) if answer_parsed else None, + 
"score": int(score), + "correct": bool(score), + "finish_reason": completion.choices[0].finish_reason, + } + + return {"score": score, "sample": sample} + + async def evaluate(self, *args, **kwargs): + start_time = time.time() + + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval( + item["question"], item["answer"], system_prompt=self.system_prompt + ) + ) + results = await tqdm_asyncio.gather(*eval_tasks) + + scores = [result["score"] for result in results] + samples = [result["sample"] for result in results] + + avg_total_score = sum(scores) / len(scores) + + end_time = time.time() + + self.eval_metrics.append(("eval/avg_total_score", avg_total_score)) + + eval_metrics = {"eval/avg_total_score": avg_total_score} + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": 0.0, + "max_tokens": self.config.max_token_length, + }, + ) + + return eval_metrics + + async def get_next_item(self): + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: + question = item["question"] + answer = item["answer"] + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": question}, + ] + + completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=1.0, + ) + + prompt_text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False) + prompt_len = len(prompt_tokens) + + scores: ScoredDataGroup = { + "tokens": [], + "masks": [], + "scores": [], + "inference_logprobs": [], + } + + for choice in completions.choices: + response = choice.message.content or "" + + # Tokenize full sequence (prompt + completion) + full_text = prompt_text + response + full_tokens = self.tokenizer.encode(full_text, add_special_tokens=False) + + # Create masks: -100 for prompt, actual tokens for completion + masks = [-100] * prompt_len + full_tokens[prompt_len:] + + logprobs = [1.0] * prompt_len + [0.0] * (len(full_tokens) - prompt_len) + + # Score using reward funcs + completion_messages = messages + [ + {"role": "assistant", "content": response} + ] + rewards = [] + for func in self.reward_funcs: + reward = func( + parser=self.parser, + completion=completion_messages, + answer=answer, + ) + rewards.append(reward) + weighted_rewards = [ + r * self.reward_scales[j] for j, r in enumerate(rewards) + ] + score = sum(weighted_rewards) + + scores["tokens"].append(full_tokens) + scores["masks"].append(masks) + scores["inference_logprobs"].append(logprobs) + scores["scores"].append(score) + + return scores, [] + + +if __name__ == "__main__": + VerifiersEnv.cli() diff --git a/pyproject.toml b/pyproject.toml index dd3841ee..653592c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "jsonlines", "pydantic-cli", "hf_transfer", + "verifiers>=0.1.8.post2", ] [project.scripts] From b62c4161300ddee104266aa8af90243a6145bd98 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 14:37:50 +0530 Subject: [PATCH 02/22] make verifiers deps optional and update README --- README.md | 1 + pyproject.toml | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 
4cf96783..7bddf763 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,7 @@ If you're looking to get into developing the repo or using the environments: pip install -e . # for using pip install -e .[dev] # for development pip install -e .[examples] # for running examples +pip install -e .[verifiers] # for verifiers integration pip install -e .[all] # for everything ``` diff --git a/pyproject.toml b/pyproject.toml index 653592c3..0b1ab5ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "jsonlines", "pydantic-cli", "hf_transfer", - "verifiers>=0.1.8.post2", ] [project.scripts] @@ -59,6 +58,9 @@ examples = [ "atroposlib[rewardfns]", "langdetect" ] +verifiers = [ + "verifiers>=0.1.5.post0" +] [build-system] requires = ["hatchling"] From 9d5cd2b593c0bdbdc366b8c7a09f056938d8aa01 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 16:18:46 +0530 Subject: [PATCH 03/22] fix: improve verifiers environments consistency and correctness - verifiers_server.py: consistent dataset column selection for train/test, remove redundant comments, preserve float precision for scores - verifiers_eval.py: add env_config_cls, fix constructor signature to match BaseEnv (slurm bool), make stub methods raise NotImplementedError --- .../eval_environments/verifiers_eval.py | 357 ++++++++++++++++++ environments/verifiers_server.py | 134 ++++--- 2 files changed, 441 insertions(+), 50 deletions(-) create mode 100644 environments/eval_environments/verifiers_eval.py diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py new file mode 100644 index 00000000..61e17656 --- /dev/null +++ b/environments/eval_environments/verifiers_eval.py @@ -0,0 +1,357 @@ +""" +Verifiers Evaluation Environment for Atropos + +This environment evaluates models using Prime Intellect's Verifiers library. +It supports any environment registered with the Verifiers ecosystem. + +To install a Verifiers/Prime environment: +1. uv tool install prime +2. prime login +3. 
prime env install will/wordle (or any owner/environment) +Docs: https://docs.primeintellect.ai/tutorials-environments/install + +Usage: + python verifiers_evaluation.py evaluate \ + --env.vf_env_name primeintellect/gsm8k \ + --openai.model_name gpt-4.1-nano \ + --openai.api_key $OPENAI_API_KEY +""" + +import asyncio +import os +import time +from typing import Dict, List, Optional, Tuple + +import verifiers as vf +from pydantic import Field +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, +) + + +class VerifiersEvaluationConfig(BaseEnvConfig): + """Configuration for Verifiers evaluation environment.""" + + # Verifiers environment + vf_env_name: str = Field( + default="", + description="Verifiers environment name (e.g., primeintellect/gsm8k)", + ) + env_args: dict = Field( + default_factory=dict, + description="Additional arguments for verifiers environment", + ) + + # Generation parameters + temperature: float = Field( + default=0.0, description="Temperature for generation (0.0 for deterministic)" + ) + max_tokens: int = Field(default=2048, description="Maximum tokens for generation") + + # Retry and debug configuration + max_retries: int = Field( + default=3, description="Maximum retries for failed API calls" + ) + retry_delay: float = Field( + default=1.0, description="Delay between retries in seconds" + ) + min_response_length: int = Field( + default=1, description="Minimum response length to consider valid" + ) + full_debug: bool = Field(default=False, description="Enable full debug output") + + # Override defaults for evaluation mode + group_size: int = 1 + max_num_workers: int = 256 + max_num_workers_per_node: int = 64 + use_wandb: bool = True + rollout_server_url: str = "http://localhost:8000" + total_steps: int = 1 + wandb_name: str = "verifiers_evaluation" + steps_per_eval: int = 1 + + +class VerifiersEvaluationEnv(BaseEnv): + """ + Verifiers Evaluation Environment. + + Evaluates models using Prime Intellect's Verifiers library rubrics. + Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, etc.) 
+ """ + + name = "verifiers_evaluation" + env_config_cls = VerifiersEvaluationConfig # type: ignore[assignment] + + def __init__( + self, + config: VerifiersEvaluationConfig, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm, testing) + self.config: VerifiersEvaluationConfig = config + + # Load verifiers environment + self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) + self.rubric = self.vf_env.rubric + + # Extract rubric components + self.parser = self.rubric.parser + self.reward_funcs = self.rubric.funcs + self.reward_weights = self.rubric.weights + self.reward_scales = [ + weight / sum(self.reward_weights) for weight in self.reward_weights + ] + self.system_prompt = self.vf_env.system_prompt + + # Tracking + self.eval_items: List[Dict] = [] + self._dataset_loaded = False + + @classmethod + def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]: + """Default configuration for evaluation.""" + env_config = VerifiersEvaluationConfig( + vf_env_name="primeintellect/gsm8k", + temperature=0.0, + max_tokens=2048, + use_wandb=True, + wandb_name="verifiers_evaluation", + ) + server_configs = [ + APIServerConfig( + model_name="gpt-4.1-nano", + base_url=None, + api_key=os.getenv("OPENAI_API_KEY"), + num_requests_for_eval=256, + ), + ] + return env_config, server_configs + + async def setup(self) -> None: + """Initialize the environment and load datasets.""" + if not self._dataset_loaded: + # Load datasets from verifiers environment + test_data = self.vf_env.get_eval_dataset() + self.eval_items = test_data.select_columns(["question", "answer"]).to_list() + self._dataset_loaded = True + + print("\nVerifiers Evaluation Setup:") + print(f" Environment: {self.config.vf_env_name}") + print(f" Reward functions: {len(self.reward_funcs)}") + print(f" Reward weights: {self.reward_weights}") + print(f" Loaded {len(self.eval_items)} evaluation items") + + async def rollout_and_score(self, item: Dict) -> Optional[Dict]: + """ + Run evaluation on a single item and return the result. 
+ + Args: + item: Dict with 'question' and 'answer' keys + + Returns: + Dict with evaluation results or None if failed + """ + question = item["question"] + answer = item["answer"] + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": question}, + ] + + # Build API call parameters + kwargs = { + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "n": 1, + } + + response_text = "" + for attempt in range(self.config.max_retries): + try: + # Direct API call (no ManagedServer) - eval doesn't need token tracking + response = await self.server.chat_completion(**kwargs) + response_text = response.choices[0].message.content or "" + + if len(response_text) >= self.config.min_response_length: + break + + except Exception as e: + if self.config.full_debug: + print(f" API error (attempt {attempt + 1}): {e}") + if attempt < self.config.max_retries - 1: + await asyncio.sleep(self.config.retry_delay) + continue + + if not response_text: + return None + + # Build completion messages for scoring + completion_messages = messages + [ + {"role": "assistant", "content": response_text} + ] + + # Parse answer + answer_parsed = self.parser.parse_answer(completion=response_text) + + # Score using reward funcs + rewards = [] + for func in self.reward_funcs: + reward = func( + parser=self.parser, + completion=completion_messages, + answer=answer, + ) + rewards.append(reward) + + weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] + score = sum(weighted_rewards) + + if self.config.full_debug: + print("\n--- Item ---") + print(f"Question: {question[:100]}...") + print(f"Gold answer: {answer}") + print(f"Model parsed: {answer_parsed}") + print(f"Rewards: {rewards}") + print(f"Score: {score}") + + return { + "question": question, + "gold_answer": answer, + "response": response_text, + "model_parsed": str(answer_parsed) if answer_parsed else None, + "rewards": rewards, + "weighted_rewards": weighted_rewards, + "score": score, + "correct": bool(score > 0), + } + + async def evaluate(self, *args, **kwargs) -> Dict: + """Run the full evaluation.""" + print(f"\n{'='*60}") + print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}") + print(f"{'='*60}") + print(f" Total questions: {len(self.eval_items)}") + print(f" Temperature: {self.config.temperature}") + print(f"{'='*60}\n") + + start_time = time.time() + + # Create evaluation tasks + tasks = [self.rollout_and_score(item) for item in self.eval_items] + + # Run with progress bar + results = await tqdm_asyncio.gather(*tasks, desc="Evaluating") + + # Filter out failed results + valid_results = [r for r in results if r is not None] + + if not valid_results: + print("Warning: No valid evaluation results obtained") + return {"error": "No valid results", "accuracy": 0.0} + + end_time = time.time() + + # Calculate metrics + total = len(valid_results) + scores = [r["score"] for r in valid_results] + correct = sum(1 for r in valid_results if r["correct"]) + + avg_score = sum(scores) / total if total > 0 else 0.0 + accuracy = correct / total if total > 0 else 0.0 + + # Per-reward function breakdown + reward_breakdown = {} + for i, weight in enumerate(self.reward_weights): + func_rewards = [r["rewards"][i] for r in valid_results] + reward_breakdown[f"reward_func_{i}"] = { + "weight": weight, + "avg": sum(func_rewards) / len(func_rewards), + "correct": sum(1 for r in func_rewards if r > 0), + } + + metrics = { + "avg_score": avg_score, + "accuracy": 
accuracy, + "total_evaluated": total, + "total_correct": correct, + "reward_breakdown": reward_breakdown, + } + + # Print results + print(f"\n{'='*60}") + print("Verifiers Evaluation Results") + print(f"{'='*60}") + print(f" Average Score: {avg_score:.4f}") + print(f" Accuracy: {accuracy:.2%} ({correct}/{total})") + print(f" Time: {end_time - start_time:.1f}s") + print("\n Per-Reward Function:") + for name, data in reward_breakdown.items(): + print( + f" {name}: avg={data['avg']:.4f}, correct={data['correct']}/{total}" + ) + print(f"{'='*60}\n") + + # Log to evaluate_log + samples = [ + { + "messages": [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": r["question"]}, + {"role": "assistant", "content": r["response"]}, + ], + "question": r["question"], + "gold_answer": r["gold_answer"], + "model_parsed": r["model_parsed"], + "score": r["score"], + "correct": r["correct"], + } + for r in valid_results + ] + + await self.evaluate_log( + metrics={"accuracy": accuracy, "avg_score": avg_score}, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + }, + ) + + return metrics + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: + """Log metrics to Weights & Biases.""" + if wandb_metrics is None: + wandb_metrics = {} + + # Add config info + wandb_metrics["config/vf_env_name"] = self.config.vf_env_name + wandb_metrics["config/temperature"] = self.config.temperature + wandb_metrics["config/max_tokens"] = self.config.max_tokens + + await super().wandb_log(wandb_metrics) + + # Required abstract method implementations (stubs for evaluation-only mode) + async def get_next_item(self) -> Optional[Dict]: + """Not used in evaluation mode.""" + raise NotImplementedError("get_next_item not supported in evaluation-only mode") + + async def collect_trajectories(self, item) -> Tuple[List, List]: + """Not used in evaluation mode.""" + raise NotImplementedError( + "collect_trajectories not supported in evaluation-only mode" + ) + + +if __name__ == "__main__": + VerifiersEvaluationEnv.cli() diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 5091dd88..873cf266 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -1,14 +1,21 @@ +# Verifiers Training Environment for Atropos +# +# NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL) +# for ALL modes (serve, process, evaluate) because it uses ManagedServer for +# token/logprob tracking. For evaluation with OpenAI API, use: +# environments/eval_environments/verifiers_eval.py +# # To install a Verifiers/Prime environment: # 1. uv tool install prime # 2. prime login # 3. prime env install will/wordle (or any owner/environment) # Docs: https://docs.primeintellect.ai/tutorials-environments/install -import os import time -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import verifiers as vf +from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( @@ -20,8 +27,12 @@ from atroposlib.envs.base import ( class VfEnvConfig(BaseEnvConfig): + """ + Configuration for the Verifiers environments. 
+ """ + vf_env_name: str = "" - env_args: dict = {} + env_args: Dict[str, Any] = Field(default_factory=dict) class VerifiersEnv(BaseEnv): @@ -36,6 +47,7 @@ class VerifiersEnv(BaseEnv): testing=False, ): super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() self.eval_metrics = list() self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric @@ -51,31 +63,59 @@ class VerifiersEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: env_config = VfEnvConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", group_size=8, - use_wandb=False, + use_wandb=True, rollout_server_url="http://localhost:8000", - total_steps=10, + total_steps=1000, batch_size=4, - steps_per_eval=1, + steps_per_eval=100, max_token_length=2048, wandb_name="verifiers", ) + # Requires local inference server (vLLM, SGLang, TRL) + # For evaluation with OpenAI, use eval_environments/verifiers_evaluation.py server_configs = [ APIServerConfig( - model_name="gpt-4.1-nano", - base_url=None, - api_key=os.getenv("OPENAI_API_KEY"), + model_name="Qwen/Qwen2.5-1.5B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", num_requests_for_eval=4, ), ] return env_config, server_configs + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Calculate percent_correct from buffer + if self.percent_correct_buffer: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + + self.percent_correct_buffer = list() + + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + + await super().wandb_log(wandb_metrics) + async def setup(self): - self.train = self.vf_env.get_dataset() + train_data = self.vf_env.get_dataset() + self.train = train_data.select_columns(["question", "answer"]).to_list() test_data = self.vf_env.get_eval_dataset() self.test = test_data.select_columns(["question", "answer"]).to_list() self.iter = 0 + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + async def rollout_and_score_eval( self, question: str, answer: str, **kwargs ) -> dict: @@ -85,12 +125,13 @@ class VerifiersEnv(BaseEnv): {"role": "user", "content": question}, ] - completion = await self.server.chat_completion( - messages=messages, - n=1, - max_tokens=self.config.max_token_length, - temperature=0.0, - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=messages, + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + ) response_content = completion.choices[0].message.content or "" messages.append({"role": "assistant", "content": response_content}) @@ -114,7 +155,7 @@ class VerifiersEnv(BaseEnv): "question": question, "gold_answer": answer, "model_parsed": str(answer_parsed) if answer_parsed else None, - "score": int(score), + "score": score, "correct": bool(score), "finish_reason": completion.choices[0].finish_reason, } @@ -171,38 +212,25 @@ class VerifiersEnv(BaseEnv): {"role": "user", "content": question}, ] - completions = await self.server.chat_completion( - messages=messages, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completions = await 
managed.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=1.0, + ) + state = managed.get_state() + nodes = state["nodes"] - prompt_text = self.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False) - prompt_len = len(prompt_tokens) + scored_data = ScoredDataGroup() + scored_data["tokens"] = list() + scored_data["masks"] = list() + scored_data["scores"] = list() + scored_data["inference_logprobs"] = list() - scores: ScoredDataGroup = { - "tokens": [], - "masks": [], - "scores": [], - "inference_logprobs": [], - } - - for choice in completions.choices: + for i, choice in enumerate(completions.choices): response = choice.message.content or "" - # Tokenize full sequence (prompt + completion) - full_text = prompt_text + response - full_tokens = self.tokenizer.encode(full_text, add_special_tokens=False) - - # Create masks: -100 for prompt, actual tokens for completion - masks = [-100] * prompt_len + full_tokens[prompt_len:] - - logprobs = [1.0] * prompt_len + [0.0] * (len(full_tokens) - prompt_len) - # Score using reward funcs completion_messages = messages + [ {"role": "assistant", "content": response} @@ -220,12 +248,18 @@ class VerifiersEnv(BaseEnv): ] score = sum(weighted_rewards) - scores["tokens"].append(full_tokens) - scores["masks"].append(masks) - scores["inference_logprobs"].append(logprobs) - scores["scores"].append(score) + # Use ManagedServer's properly aligned tokens/masks/logprobs + node = nodes[i] + scored_data["tokens"].append(node.tokens) + scored_data["masks"].append(node.masked_tokens) + scored_data["inference_logprobs"].append(node.logprobs) + scored_data["scores"].append(score) - return scores, [] + # Track scores for wandb logging + for score in scored_data["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + + return scored_data, [] if __name__ == "__main__": From dda85430da9960d2a97d602d6b73c822586c91c1 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 16:25:44 +0530 Subject: [PATCH 04/22] fix docstrings --- .../eval_environments/verifiers_eval.py | 14 +++++------ environments/verifiers_server.py | 24 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 61e17656..6d7c3755 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -11,7 +11,7 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install Usage: - python verifiers_evaluation.py evaluate \ + python verifiers_eval.py evaluate \ --env.vf_env_name primeintellect/gsm8k \ --openai.model_name gpt-4.1-nano \ --openai.api_key $OPENAI_API_KEY @@ -235,12 +235,12 @@ class VerifiersEvaluationEnv(BaseEnv): async def evaluate(self, *args, **kwargs) -> Dict: """Run the full evaluation.""" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f" Total questions: {len(self.eval_items)}") print(f" Temperature: {self.config.temperature}") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") start_time = time.time() @@ -286,9 +286,9 @@ class VerifiersEvaluationEnv(BaseEnv): } # Print results - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Verifiers 
Evaluation Results") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f" Average Score: {avg_score:.4f}") print(f" Accuracy: {accuracy:.2%} ({correct}/{total})") print(f" Time: {end_time - start_time:.1f}s") @@ -297,7 +297,7 @@ class VerifiersEvaluationEnv(BaseEnv): print( f" {name}: avg={data['avg']:.4f}, correct={data['correct']}/{total}" ) - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Log to evaluate_log samples = [ diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 873cf266..e7216919 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -1,15 +1,15 @@ -# Verifiers Training Environment for Atropos -# -# NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL) -# for ALL modes (serve, process, evaluate) because it uses ManagedServer for -# token/logprob tracking. For evaluation with OpenAI API, use: -# environments/eval_environments/verifiers_eval.py -# -# To install a Verifiers/Prime environment: -# 1. uv tool install prime -# 2. prime login -# 3. prime env install will/wordle (or any owner/environment) -# Docs: https://docs.primeintellect.ai/tutorials-environments/install +""" +Verifiers Training Environment for Atropos +NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL) +for ALL modes (serve, process, evaluate) because it uses ManagedServer for +token/logprob tracking. For evaluation with OpenAI API, use: `environments/eval_environments/verifiers_eval.py` + +To install a Verifiers/Prime environment: +1. uv tool install prime +2. prime login +3. prime env install will/wordle (or any owner/environment) +Docs: https://docs.primeintellect.ai/tutorials-environments/install +""" import time from typing import Any, Dict, List, Optional, Tuple From 636715bb08b2fb41874d1b9a1040dabd6d2a9653 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 16:51:19 +0530 Subject: [PATCH 05/22] add wandb to eval --- .../eval_environments/verifiers_eval.py | 57 ++++++++++--------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 6d7c3755..4efcd26e 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -26,6 +26,7 @@ import verifiers as vf from pydantic import Field from tqdm.asyncio import tqdm_asyncio +import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, @@ -46,13 +47,10 @@ class VerifiersEvaluationConfig(BaseEnvConfig): description="Additional arguments for verifiers environment", ) - # Generation parameters temperature: float = Field( default=0.0, description="Temperature for generation (0.0 for deterministic)" ) - max_tokens: int = Field(default=2048, description="Maximum tokens for generation") - # Retry and debug configuration max_retries: int = Field( default=3, description="Maximum retries for failed API calls" ) @@ -64,16 +62,6 @@ class VerifiersEvaluationConfig(BaseEnvConfig): ) full_debug: bool = Field(default=False, description="Enable full debug output") - # Override defaults for evaluation mode - group_size: int = 1 - max_num_workers: int = 256 - max_num_workers_per_node: int = 64 - use_wandb: bool = True - rollout_server_url: str = "http://localhost:8000" - total_steps: int = 1 - wandb_name: str = "verifiers_evaluation" - steps_per_eval: int = 1 - class VerifiersEvaluationEnv(BaseEnv): """ @@ -118,17 +106,11 @@ class VerifiersEvaluationEnv(BaseEnv): 
"""Default configuration for evaluation.""" env_config = VerifiersEvaluationConfig( vf_env_name="primeintellect/gsm8k", - temperature=0.0, - max_tokens=2048, - use_wandb=True, - wandb_name="verifiers_evaluation", ) server_configs = [ APIServerConfig( model_name="gpt-4.1-nano", - base_url=None, api_key=os.getenv("OPENAI_API_KEY"), - num_requests_for_eval=256, ), ] return env_config, server_configs @@ -169,7 +151,7 @@ class VerifiersEvaluationEnv(BaseEnv): kwargs = { "messages": messages, "temperature": self.config.temperature, - "max_tokens": self.config.max_tokens, + "max_tokens": self.config.max_token_length, "n": 1, } @@ -323,23 +305,42 @@ class VerifiersEvaluationEnv(BaseEnv): end_time=end_time, generation_parameters={ "temperature": self.config.temperature, - "max_tokens": self.config.max_tokens, + "max_tokens": self.config.max_token_length, }, ) + # Log to wandb + await self.wandb_log(metrics) + return metrics async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: """Log metrics to Weights & Biases.""" - if wandb_metrics is None: - wandb_metrics = {} + if not self.config.use_wandb or wandb_metrics is None: + return - # Add config info - wandb_metrics["config/vf_env_name"] = self.config.vf_env_name - wandb_metrics["config/temperature"] = self.config.temperature - wandb_metrics["config/max_tokens"] = self.config.max_tokens + # Lazy init if wandb not already initialized + if wandb.run is None: + wandb.init( + project="verifiers-eval", + name=self.config.wandb_name, + config=self.config.model_dump(), + ) - await super().wandb_log(wandb_metrics) + log_dict = { + "verifiers/accuracy": wandb_metrics.get("accuracy", 0), + "verifiers/avg_score": wandb_metrics.get("avg_score", 0), + "verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0), + "verifiers/total_correct": wandb_metrics.get("total_correct", 0), + } + + # Add per-reward function metrics + reward_breakdown = wandb_metrics.get("reward_breakdown", {}) + for func_name, data in reward_breakdown.items(): + log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0) + log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0) + + wandb.log(log_dict) # Required abstract method implementations (stubs for evaluation-only mode) async def get_next_item(self) -> Optional[Dict]: From 5b09ad86f4f9acb40d6236d0b8031818f7909cd1 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 9 Jan 2026 19:20:41 +0530 Subject: [PATCH 06/22] update readme, add sft-datagen to verifiers_server --- environments/README.md | 153 ++++++++++++++ .../eval_environments/verifiers_eval.py | 11 +- environments/verifiers_server.py | 194 ++++++++++++++---- 3 files changed, 320 insertions(+), 38 deletions(-) diff --git a/environments/README.md b/environments/README.md index 79201afb..6000bf58 100644 --- a/environments/README.md +++ b/environments/README.md @@ -11,7 +11,160 @@ This directory contains various environments for training and evaluating languag --- +### Prime Intellect Verifiers Integration +A flexible environment that integrates with the [Verifiers](https://docs.primeintellect.ai/) ecosystem, allowing you to use any registered Prime environment for RL training, SFT data generation, or evaluation. 
+
+**Files:**
+- `environments/verifiers_server.py` - Training and SFT data generation
+- `environments/eval_environments/verifiers_eval.py` - Standalone evaluation
+
+**Dependencies:**
+
+- `verifiers` Python package (install via `pip install verifiers` or include in your environment)
+- Prime CLI for environment management (`uv tool install prime`)
+- Prime CLI login required (`prime login`)
+- Environment installation (`prime env install owner/env_name`)
+
+**Supported Modes:**
+
+| Mode | File | Description |
+|------|------|-------------|
+| `serve` | `verifiers_server.py` | RL training with local inference server (requires ManagedServer for logprobs) |
+| `process` | `verifiers_server.py` | SFT data generation with ANY API (OpenAI, Claude, local, etc.) |
+| `evaluate` | `verifiers_server.py` | Quick evaluation using ManagedServer |
+| `evaluate` | `verifiers_eval.py` | Standalone evaluation with detailed metrics and retry logic |
+
+**Input Format:**
+
+- Loaded dynamically from the specified Prime environment via `vf.load_environment()`
+- Each item contains:
+  - `question`: The problem/prompt
+  - `answer`: The expected answer for verification
+
+**System Prompt:**
+
+- Dynamically loaded from the Prime environment's `system_prompt` configuration
+
+**Reward Function:**
+
+- Uses the environment's **rubric** system with:
+  - `parser`: Extracts answers from completions (e.g., `parser.parse_answer(completion)`)
+  - `funcs`: List of reward functions that receive `(parser, completion, answer)`
+  - `weights`: Weights for combining reward functions (normalized to sum to 1.0)
+- Final score is weighted sum of all reward function outputs
+
+**W&B Metrics Logged (Training - `verifiers_server.py`):**
+
+| Metric | Description |
+|--------|-------------|
+| `train/percent_correct` | Average score from verifiers reward functions (0-1) |
+| `train/rollouts` | Table of tokenized completions with scores |
+| `train/completion_lengths_*` | Response length statistics (std, min, max, p95) |
+| `server/server_0_request_time_*` | API latency metrics (avg, std, 99p) |
+| `eval/avg_total_score` | Average score on evaluation dataset |
+
+**W&B Metrics Logged (Evaluation - `verifiers_eval.py`):**
+
+| Metric | Description |
+|--------|-------------|
+| `verifiers/accuracy` | Proportion of items with score > 0 |
+| `verifiers/avg_score` | Average weighted score across all items |
+| `verifiers/total_evaluated` | Number of successfully evaluated items |
+| `verifiers/total_correct` | Number of items with score > 0 |
+| `verifiers/reward_func_N_avg` | Per-reward function average score |
+| `verifiers/reward_func_N_correct` | Per-reward function correct count |
+
+**Configuration Options (`VfEnvConfig` for `verifiers_server.py`):**
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `vf_env_name` | str | `""` | Prime environment identifier (e.g., `"will/wordle"`, `"primeintellect/gsm8k"`) |
+| `env_args` | Dict | `{}` | Additional arguments passed to `vf.load_environment()`. Read environment specific documentation to get these args. |
+
+**Configuration Options (`VerifiersEvaluationConfig` for `verifiers_eval.py`):**
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `vf_env_name` | str | `""` | Prime environment identifier |
+| `env_args` | dict | `{}` | Additional arguments for verifiers environment |
+| `temperature` | float | `0.0` | Temperature for generation (0.0 for deterministic) |
+| `max_retries` | int | `3` | Maximum retries for failed API calls |
+| `retry_delay` | float | `1.0` | Delay between retries in seconds |
+| `min_response_length` | int | `1` | Minimum response length to consider valid |
+| `full_debug` | bool | `False` | Enable verbose per-item debug output |
+| `max_eval_items` | int | `-1` | Maximum number of items to evaluate (-1 for all) |
+
+**Usage Examples:**
+
+```bash
+# RL Training (requires local vLLM/SGLang server)
+python verifiers_server.py serve \
+  --env.vf_env_name "will/wordle" \
+  --openai.base_url http://localhost:9001/v1 \
+  --slurm false
+
+# SFT Data Generation with OpenAI GPT-4o
+python verifiers_server.py process \
+  --env.vf_env_name "will/wordle" \
+  --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
+  --env.total_steps 100 \
+  --env.group_size 4 \
+  --openai.model_name gpt-4o \
+  --openai.base_url https://api.openai.com/v1
+
+# SFT Data Generation with local server
+python verifiers_server.py process \
+  --env.vf_env_name "will/wordle" \
+  --env.data_path_to_save_groups local_sft_data.jsonl \
+  --openai.base_url http://localhost:9001/v1
+
+# Quick Evaluation via verifiers_server.py
+python verifiers_server.py evaluate \
+  --env.vf_env_name "will/wordle" \
+  --openai.base_url http://localhost:9001/v1
+
+# Standalone Evaluation with detailed metrics (verifiers_eval.py)
+python eval_environments/verifiers_eval.py evaluate \
+  --env.vf_env_name "primeintellect/gsm8k" \
+  --openai.model_name gpt-4o \
+  --openai.api_key $OPENAI_API_KEY
+
+# Quick test run with limited items
+python eval_environments/verifiers_eval.py evaluate \
+  --env.vf_env_name "primeintellect/gsm8k" \
+  --env.max_eval_items 10 \
+  --openai.model_name gpt-4o \
+  --openai.api_key $OPENAI_API_KEY
+
+# Evaluation with debug output
+python eval_environments/verifiers_eval.py evaluate \
+  --env.vf_env_name "primeintellect/gsm8k" \
+  --env.full_debug true \
+  --openai.base_url http://localhost:9001/v1
+```
+
+**Key Implementation Details:**
+
+- **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`.
+- **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs.
+- **Evaluation (`evaluate`)**: Runs on the environment's eval dataset with greedy decoding (temperature=0). The standalone `verifiers_eval.py` provides more detailed metrics and retry logic for production evaluation.
+ +**Prime Environment Installation:** +```bash +# Install Prime CLI +uv tool install prime + +# Login to Prime +prime login + +# Install an environment (e.g., Wordle, GSM8K) +prime env install will/wordle +prime env install primeintellect/gsm8k + +# List available environments +prime env list +``` ### Letter Counting Environment (`letter_counting_environment.py`) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 4efcd26e..9059657e 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -61,6 +61,9 @@ class VerifiersEvaluationConfig(BaseEnvConfig): default=1, description="Minimum response length to consider valid" ) full_debug: bool = Field(default=False, description="Enable full debug output") + max_eval_items: int = Field( + default=-1, description="Maximum number of items to evaluate (-1 for all)" + ) class VerifiersEvaluationEnv(BaseEnv): @@ -110,6 +113,7 @@ class VerifiersEvaluationEnv(BaseEnv): server_configs = [ APIServerConfig( model_name="gpt-4.1-nano", + base_url="https://api.openai.com/v1", api_key=os.getenv("OPENAI_API_KEY"), ), ] @@ -121,6 +125,11 @@ class VerifiersEvaluationEnv(BaseEnv): # Load datasets from verifiers environment test_data = self.vf_env.get_eval_dataset() self.eval_items = test_data.select_columns(["question", "answer"]).to_list() + + # Limit items if max_eval_items is set + if self.config.max_eval_items > 0: + self.eval_items = self.eval_items[: self.config.max_eval_items] + self._dataset_loaded = True print("\nVerifiers Evaluation Setup:") @@ -322,7 +331,7 @@ class VerifiersEvaluationEnv(BaseEnv): # Lazy init if wandb not already initialized if wandb.run is None: wandb.init( - project="verifiers-eval", + project="atropos-environments", name=self.config.wandb_name, config=self.config.model_dump(), ) diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index e7216919..434e1368 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -1,8 +1,36 @@ """ Verifiers Training Environment for Atropos -NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL) -for ALL modes (serve, process, evaluate) because it uses ManagedServer for -token/logprob tracking. For evaluation with OpenAI API, use: `environments/eval_environments/verifiers_eval.py` + +Supports TWO modes: +- serve: RL training with local inference server (requires ManagedServer for logprobs) +- process: SFT data generation with ANY API (OpenAI, Claude, local, etc.) 
+ +Usage: + # RL Training (requires local vLLM/SGLang server) + python verifiers_server.py serve \ + --env.vf_env_name "will/wordle" \ + --openai.base_url http://localhost:9001/v1 \ + --slurm false + + # SFT Data Generation with OpenAI GPT-4o + python verifiers_server.py process \ + --env.vf_env_name "will/wordle" \ + --env.data_path_to_save_groups gpt4o_sft_data.jsonl \ + --env.total_steps 100 \ + --env.group_size 4 \ + --openai.model_name gpt-4o \ + --openai.base_url https://api.openai.com/v1 + + # SFT Data Generation with local server + python verifiers_server.py process \ + --env.vf_env_name "will/wordle" \ + --env.data_path_to_save_groups local_sft_data.jsonl \ + --openai.base_url http://localhost:9001/v1 + + # Evaluation (uses ManagedServer by default, falls back to direct API in process mode) + python verifiers_server.py evaluate \ + --env.vf_env_name "will/wordle" \ + --openai.base_url http://localhost:9001/v1 To install a Verifiers/Prime environment: 1. uv tool install prime @@ -24,6 +52,7 @@ from atroposlib.envs.base import ( BaseEnvConfig, ScoredDataGroup, ) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer class VfEnvConfig(BaseEnvConfig): @@ -73,12 +102,13 @@ class VerifiersEnv(BaseEnv): max_token_length=2048, wandb_name="verifiers", ) - # Requires local inference server (vLLM, SGLang, TRL) - # For evaluation with OpenAI, use eval_environments/verifiers_evaluation.py + # Default config for local inference server (vLLM, SGLang, TRL) + # For SFT data generation with OpenAI, override via CLI: + # --openai.base_url https://api.openai.com/v1 --openai.model_name gpt-4o server_configs = [ APIServerConfig( - model_name="Qwen/Qwen2.5-1.5B-Instruct", - base_url="http://localhost:9001/v1", + model_name="gpt-4.1-nano", + base_url="https://api.openai.com/v1", api_key="x", num_requests_for_eval=4, ), @@ -116,39 +146,58 @@ class VerifiersEnv(BaseEnv): data["iter"] = self.iter super().save_checkpoint(step, data) + def _compute_score(self, completion_messages: List[Dict], answer: str) -> float: + """Compute score using verifiers reward functions.""" + rewards = [] + for func in self.reward_funcs: + reward = func( + parser=self.parser, + completion=completion_messages, + answer=answer, + ) + rewards.append(reward) + weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] + return sum(weighted_rewards) + async def rollout_and_score_eval( self, question: str, answer: str, **kwargs ) -> dict: + """ + Rollout and score for evaluation. + Uses ManagedServer in serve mode, direct API calls in process mode. 
+ """ system_prompt = kwargs.get("system_prompt") messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question}, ] - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - completion = await managed.chat_completion( + is_process_mode = getattr(self, "process_mode", False) + + if is_process_mode: + # Process mode: use direct API call (works with any API) + completion = await self.server.chat_completion( messages=messages, n=1, max_tokens=self.config.max_token_length, temperature=0.0, ) + else: + # Serve mode: use ManagedServer for token tracking + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=messages, + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + ) response_content = completion.choices[0].message.content or "" messages.append({"role": "assistant", "content": response_content}) answer_parsed = self.parser.parse_answer(completion=response_content) - rewards = [] - for func in self.reward_funcs: - reward = func( - parser=self.parser, - completion=messages, - answer=answer, - ) - rewards.append(reward) - weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] - - score = sum(weighted_rewards) + score = self._compute_score(messages, answer) sample = { "messages": messages, @@ -204,6 +253,11 @@ class VerifiersEnv(BaseEnv): return next_item async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: + """ + Collect trajectories - automatically switches between: + - ManagedServer (for RL training with local server - requires logprobs) + - tokenize_for_trainer (for SFT datagen with any API - no logprobs needed) + """ question = item["question"] answer = item["answer"] @@ -212,6 +266,81 @@ class VerifiersEnv(BaseEnv): {"role": "user", "content": question}, ] + # Check if we're in process mode (SFT data generation) + is_process_mode = getattr(self, "process_mode", False) + + if is_process_mode: + return await self._collect_trajectories_for_sft(messages, answer) + else: + return await self._collect_trajectories_for_training(messages, answer) + + async def _collect_trajectories_for_sft( + self, messages: List[Dict], answer: str + ) -> Tuple[ScoredDataGroup, list]: + """ + SFT data generation mode - works with ANY API (OpenAI, Claude, local). + Does NOT require logprobs or local server. + + Uses tokenize_for_trainer to tokenize completions with your training + tokenizer, so the resulting data is ready for fine-tuning your target model. 
+ """ + completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=1.0, + ) + + scored_data = ScoredDataGroup() + scored_data["tokens"] = [] + scored_data["masks"] = [] + scored_data["scores"] = [] + scored_data["messages"] = [] + # Note: No inference_logprobs - not needed/available for SFT + + for choice in completions.choices: + response = choice.message.content or "" + finish_reason = choice.finish_reason or "" + + # Build full conversation for scoring and tokenization + completion_messages = messages + [ + {"role": "assistant", "content": response} + ] + + # Score using verifiers reward funcs + score = self._compute_score(completion_messages, answer) + + # Use tokenize_for_trainer for tokenization + # This uses YOUR training tokenizer (e.g., Qwen, Llama), not the API's tokenizer + # So GPT-4o responses get tokenized for your target model + tokenized = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=completion_messages, + include_messages=True, + finish_reason=finish_reason, + ) + + scored_data["tokens"].append(tokenized["tokens"]) + scored_data["masks"].append(tokenized["masks"]) + scored_data["messages"].append(completion_messages) + scored_data["scores"].append(score) + + # Track scores for wandb logging + for score in scored_data["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + + return scored_data, [] + + async def _collect_trajectories_for_training( + self, messages: List[Dict], answer: str + ) -> Tuple[ScoredDataGroup, list]: + """ + RL training mode - requires local inference server. + Uses ManagedServer for proper token/logprob alignment. + + The inference_logprobs are required for policy gradient methods like + GRPO, PPO, REINFORCE, etc. + """ async with self.server.managed_server(tokenizer=self.tokenizer) as managed: completions = await managed.chat_completion( messages=messages, @@ -223,30 +352,21 @@ class VerifiersEnv(BaseEnv): nodes = state["nodes"] scored_data = ScoredDataGroup() - scored_data["tokens"] = list() - scored_data["masks"] = list() - scored_data["scores"] = list() - scored_data["inference_logprobs"] = list() + scored_data["tokens"] = [] + scored_data["masks"] = [] + scored_data["scores"] = [] + scored_data["inference_logprobs"] = [] # Required for RL training! 
for i, choice in enumerate(completions.choices): response = choice.message.content or "" - # Score using reward funcs + # Build full conversation for scoring completion_messages = messages + [ {"role": "assistant", "content": response} ] - rewards = [] - for func in self.reward_funcs: - reward = func( - parser=self.parser, - completion=completion_messages, - answer=answer, - ) - rewards.append(reward) - weighted_rewards = [ - r * self.reward_scales[j] for j, r in enumerate(rewards) - ] - score = sum(weighted_rewards) + + # Score using verifiers reward funcs + score = self._compute_score(completion_messages, answer) # Use ManagedServer's properly aligned tokens/masks/logprobs node = nodes[i] From 3449a4c23d06493d41408685eebaa260946c71f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 11:22:01 +0000 Subject: [PATCH 07/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- environments/eval_environments/verifiers_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 9059657e..98b95ffa 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -23,10 +23,10 @@ import time from typing import Dict, List, Optional, Tuple import verifiers as vf +import wandb from pydantic import Field from tqdm.asyncio import tqdm_asyncio -import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, From cf636595d22696cc55537d99a623874924131bd0 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Sat, 10 Jan 2026 14:55:08 +0530 Subject: [PATCH 08/22] rework server and eval for rl rollout. add in asyncmanagedserver for verifiers --- .../server_handling/atropos_managed_client.py | 330 ++++++++++++++++++ environments/configs/verifiers.yaml | 31 ++ .../eval_environments/verifiers_eval.py | 116 ++++-- environments/verifiers_server.py | 266 ++++++++++---- 4 files changed, 652 insertions(+), 91 deletions(-) create mode 100644 atroposlib/envs/server_handling/atropos_managed_client.py create mode 100644 environments/configs/verifiers.yaml diff --git a/atroposlib/envs/server_handling/atropos_managed_client.py b/atroposlib/envs/server_handling/atropos_managed_client.py new file mode 100644 index 00000000..75e78e39 --- /dev/null +++ b/atroposlib/envs/server_handling/atropos_managed_client.py @@ -0,0 +1,330 @@ +""" +AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer. + +This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's +ManagedServer for inference, enabling token tracking for multi-turn RL training +with the Verifiers library. 
+ +Usage: + async with server_manager.managed_server(tokenizer=tokenizer) as managed: + client = AtroposManagedClient(managed_server=managed, model="model-name") + + # Use like AsyncOpenAI - tokens are tracked automatically + response = await client.chat.completions.create( + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100 + ) + + # Token data is available on the response: + # - response.prompt_token_ids + # - response.choices[0].token_ids + # - response.choices[0].logprobs.content[i].logprob +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode + +# ============================================================================= +# Enhanced Types for Token Data Injection +# ============================================================================= + + +@dataclass +class LogprobContent: + """ + Single token logprob entry. + + Compatible with verifiers' parse_response_tokens() which accesses: + - response.choices[i].logprobs.content[j].logprob + """ + + logprob: float + token: str = "" + token_id: int = 0 + top_logprobs: Optional[List[Any]] = None + + +@dataclass +class ChoiceLogprobs: + """ + Logprobs structure compatible with verifiers expectations. + + Verifiers checks for either object or dict format: + - Object: response.choices[i].logprobs.content[j].logprob + - Dict: response.choices[i].logprobs["content"][j]["logprob"] + + This dataclass supports the object format. + """ + + content: List[LogprobContent] = field(default_factory=list) + + +@dataclass +class EnhancedChoice: + """ + Choice with token_ids and logprobs for RL training. + + Adds the following attributes that verifiers expects: + - token_ids: List[int] - completion token IDs + - logprobs: ChoiceLogprobs - structured logprobs + """ + + index: int + message: ChatCompletionMessage + finish_reason: str + token_ids: List[int] + logprobs: ChoiceLogprobs + + +@dataclass +class EnhancedChatCompletion: + """ + ChatCompletion with token data for RL training. + + Compatible with verifiers' parse_response_tokens() expectations: + - prompt_token_ids: list[int] + - choices[i].token_ids: list[int] + - choices[i].logprobs.content[j].logprob + """ + + id: str + created: int + model: str + object: str + choices: List[EnhancedChoice] + prompt_token_ids: List[int] + usage: Optional[Dict[str, int]] = None + + +# ============================================================================= +# AsyncOpenAI-Compatible Client Classes +# ============================================================================= + + +class _CompletionsNamespace: + """ + Mimics openai.resources.chat.completions.AsyncCompletions. + + Provides the create() method that verifiers calls. + """ + + def __init__(self, parent: "AtroposManagedClient"): + self.parent = parent + + async def create( + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + n: int = 1, + max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict]] = None, + stop: Optional[List[str]] = None, + **kwargs, + ) -> EnhancedChatCompletion: + """ + Create chat completion with token tracking. 
+ + Returns ChatCompletion with additional attributes: + - prompt_token_ids: list[int] + - choices[i].token_ids: list[int] + - choices[i].logprobs.content: list with logprob info + + Args: + messages: List of message dicts with 'role' and 'content' + model: Model name (defaults to client's model) + n: Number of completions (should be 1 for multi-turn) + max_tokens: Max tokens in completion (legacy param) + max_completion_tokens: Max tokens in completion (new param) + temperature: Sampling temperature + top_p: Nucleus sampling parameter + tools: Tool definitions for function calling + stop: Stop sequences + **kwargs: Additional parameters passed to ManagedServer + """ + # Use max_completion_tokens if provided, else max_tokens + effective_max_tokens = max_completion_tokens or max_tokens + + # Build kwargs for ManagedServer + completion_kwargs = { + "messages": messages, + "model": model or self.parent.model, + "n": n, + "temperature": temperature, + "top_p": top_p, + } + + if effective_max_tokens is not None: + completion_kwargs["max_tokens"] = effective_max_tokens + + if tools is not None: + completion_kwargs["tools"] = tools + + if stop is not None: + completion_kwargs["stop"] = stop + + # Add any extra kwargs (like logprobs settings) + for key, value in kwargs.items(): + if value is not None: + completion_kwargs[key] = value + + # Call ManagedServer for inference + completion = await self.parent.managed_server.chat_completion( + **completion_kwargs + ) + + # Get token state from managed server + state = self.parent.managed_server.get_state() + nodes: List[SequenceNode] = state["nodes"] + + # Inject token data into response + return self._enhance_completion(completion, nodes) + + def _enhance_completion( + self, completion: Any, nodes: List[SequenceNode] + ) -> EnhancedChatCompletion: + """ + Convert ManagedServer output to verifiers-compatible format. + + Extracts token data from SequenceNodes and injects it into the + ChatCompletion response in the format verifiers expects. 
+ """ + enhanced_choices = [] + prompt_token_ids: List[int] = [] + + for i, (choice, node) in enumerate(zip(completion.choices, nodes)): + # Find prompt/completion boundary from masked_tokens + # -100 indicates prompt tokens, actual token IDs indicate completion + prompt_len = sum(1 for m in node.masked_tokens if m == -100) + + # Extract prompt and completion portions + if i == 0: + prompt_token_ids = node.tokens[:prompt_len] + + completion_ids = node.tokens[prompt_len:] + completion_logprobs = node.logprobs[prompt_len:] + + # Build logprobs structure verifiers expects + logprobs_content = [] + tokenizer = self.parent.managed_server.tokenizer + + for token_id, logprob in zip(completion_ids, completion_logprobs): + # Decode token to string if tokenizer available + token_str = "" + if tokenizer is not None: + try: + token_str = tokenizer.decode([token_id]) + except Exception: + token_str = f"" + + logprobs_content.append( + LogprobContent( + logprob=logprob, + token=token_str, + token_id=token_id, + ) + ) + + # Create enhanced choice with token data + enhanced_choice = EnhancedChoice( + index=choice.index, + message=choice.message, + finish_reason=choice.finish_reason or "stop", + token_ids=completion_ids, + logprobs=ChoiceLogprobs(content=logprobs_content), + ) + enhanced_choices.append(enhanced_choice) + + return EnhancedChatCompletion( + id=completion.id, + created=completion.created, + model=completion.model, + object=completion.object, + choices=enhanced_choices, + prompt_token_ids=prompt_token_ids, + usage=completion.usage.model_dump() if completion.usage else None, + ) + + +class _ChatNamespace: + """Mimics openai.resources.chat.AsyncChat.""" + + def __init__(self, parent: "AtroposManagedClient"): + self.completions = _CompletionsNamespace(parent) + + +class AtroposManagedClient: + """ + AsyncOpenAI-compatible client backed by ManagedServer. + + This client provides the same interface as AsyncOpenAI but uses Atropos's + ManagedServer for inference, enabling automatic token tracking for + multi-turn RL training with the Verifiers library. + + The key feature is that responses include token data attributes that + verifiers' parse_response_tokens() expects: + - response.prompt_token_ids + - response.choices[i].token_ids + - response.choices[i].logprobs.content[j].logprob + + Usage: + async with server_manager.managed_server(tokenizer=tokenizer) as managed: + client = AtroposManagedClient( + managed_server=managed, + model="Qwen/Qwen2.5-1.5B-Instruct" + ) + + # Pass to verifiers env.rollout() + state = await vf_env.rollout( + input=rollout_input, + client=client, + model="Qwen/Qwen2.5-1.5B-Instruct", + ) + + # Token data is now in state["trajectory"][i]["tokens"] + """ + + def __init__( + self, + managed_server: ManagedServer, + model: str, + base_url: Optional[str] = None, + ): + """ + Initialize the managed client. 
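+
+        A typical RL loop keeps one client per managed-server context and calls
+        reset() between rollouts so token tracking starts fresh each time, e.g.
+        (illustrative; the model name and rollout inputs are placeholders):
+
+            client = AtroposManagedClient(managed_server=managed, model="my-model")
+            for rollout_input in rollout_inputs:
+                client.reset()
+                state = await vf_env.rollout(
+                    input=rollout_input, client=client, model="my-model"
+                )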
+ + Args: + managed_server: ManagedServer instance for inference and token tracking + model: Model name to use for completions + base_url: Optional base URL (for API compatibility, not used) + """ + self.managed_server = managed_server + self.model = model + self.base_url = base_url or "http://managed-server" + + # Mimic AsyncOpenAI namespace structure + self.chat = _ChatNamespace(self) + + def reset(self): + """Reset token tracking state between rollouts.""" + self.managed_server.reset() + + async def close(self): + """Compatibility method - no-op since ManagedServer handles cleanup.""" + pass + + def copy(self, **_kwargs) -> "AtroposManagedClient": + """ + Create a copy of this client (for API compatibility). + + Verifiers may call client.copy() for certain operations. + Returns self since we want to maintain the same ManagedServer state. + """ + return self diff --git a/environments/configs/verifiers.yaml b/environments/configs/verifiers.yaml new file mode 100644 index 00000000..91ef7ec2 --- /dev/null +++ b/environments/configs/verifiers.yaml @@ -0,0 +1,31 @@ +# Verifiers environment configuration +# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml +# +# For SFT data generation with external API: +# python environments/verifiers_server.py process \ +# --env.vf_env_name primeintellect/gsm8k \ +# --env.data_path_to_save_groups output.jsonl \ +# --openai.base_url https://api.openai.com/v1 \ +# --openai.api_key $OPENAI_API_KEY \ +# --openai.model_name gpt-4o + +env: + vf_env_name: "primeintellect/gsm8k" # Prime Env Hub environment + env_args: {} + group_size: 8 + max_token_length: 2048 + tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct" + rollout_server_url: "http://localhost:8000" + use_wandb: true + wandb_name: "verifiers" + total_steps: 1000 + batch_size: 4 + steps_per_eval: 100 + +openai: + - model_name: "Qwen/Qwen2.5-1.5B-Instruct" + base_url: "http://localhost:9001/v1" + api_key: "x" + +slurm: false +testing: false diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 98b95ffa..c5fb1063 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -18,15 +18,15 @@ Usage: """ import asyncio +import inspect import os import time -from typing import Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import verifiers as vf -import wandb from pydantic import Field -from tqdm.asyncio import tqdm_asyncio +import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, @@ -34,6 +34,36 @@ from atroposlib.envs.base import ( ) +# Patch math_verify timeout to work in async context +# The signal-based timeout doesn't work in non-main threads (asyncio event loop) +def _no_signal_timeout(timeout_seconds: int): + """Replacement timeout decorator that doesn't use signals.""" + + def decorator(func): + def wrapper(*args, **kwargs): + # Just call the function without timeout + # This is safe because we're in an async context with our own timeouts + # timeout_seconds is intentionally unused - we're replacing the timeout logic + return func(*args, **kwargs) + + return wrapper + + return decorator + + +try: + import math_verify.grader + import math_verify.parser + import math_verify.utils + + # Patch all modules that use the timeout decorator + math_verify.utils.timeout = _no_signal_timeout + math_verify.parser.timeout = _no_signal_timeout + math_verify.grader.timeout = _no_signal_timeout 
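+    # Background: Python only allows installing signal handlers (signal.signal)
+    # from the main thread of the main interpreter; from a worker thread it
+    # raises ValueError, hence the pass-through decorator patched in above.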
+except ImportError: + pass # math_verify not installed + + class VerifiersEvaluationConfig(BaseEnvConfig): """Configuration for Verifiers evaluation environment.""" @@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv): self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric - # Extract rubric components + # Extract rubric components from RubricGroup + # RubricGroup.funcs is empty - need to collect from individual rubrics self.parser = self.rubric.parser - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights - self.reward_scales = [ - weight / sum(self.reward_weights) for weight in self.reward_weights - ] + self.reward_funcs: List[Callable] = [] + self.reward_weights: List[float] = [] + self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func + + if hasattr(self.rubric, "rubrics"): + # RubricGroup: collect from all individual rubrics + for rubric in self.rubric.rubrics: + class_objects = getattr(rubric, "class_objects", {}) + for func, weight in zip(rubric.funcs, rubric.weights): + self.reward_funcs.append(func) + self.reward_weights.append(weight) + self.rubric_class_objects.append(class_objects) + else: + # Single Rubric + self.reward_funcs = self.rubric.funcs + self.reward_weights = self.rubric.weights + class_objects = getattr(self.rubric, "class_objects", {}) + self.rubric_class_objects = [class_objects] * len(self.rubric.funcs) + + total_weight = sum(self.reward_weights) if self.reward_weights else 1.0 + self.reward_scales = [weight / total_weight for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt # Tracking @@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv): # Parse answer answer_parsed = self.parser.parse_answer(completion=response_text) - # Score using reward funcs + # Score using reward funcs (async functions need await) + # Use signature introspection to pass only required params (like verifiers does) rewards = [] - for func in self.reward_funcs: - reward = func( - parser=self.parser, - completion=completion_messages, - answer=answer, - ) + for i, func in enumerate(self.reward_funcs): + try: + # Build merged dict of all possible parameters + class_objects = self.rubric_class_objects[i] + merged = { + "completion": completion_messages, + "answer": answer, + "prompt": question, + } + merged.update(class_objects) # Adds parser, etc. 
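+                # Illustrative contents at this point (varies with the rubric's
+                # class_objects): {"completion": [...messages...], "answer": "42",
+                #  "prompt": "<question text>", "parser": <Parser instance>, ...}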
+ + # Filter to only params the function accepts + sig = inspect.signature(func) + if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + # Function accepts **kwargs, pass everything + kwargs = merged + else: + # Only pass params in signature + kwargs = {k: v for k, v in merged.items() if k in sig.parameters} + + result = func(**kwargs) + # Reward functions may be async coroutines + if asyncio.iscoroutine(result): + reward = await result + else: + reward = result + reward = float(reward) + except Exception as e: + if self.config.full_debug: + print(f" Reward func {func.__name__} error: {e}") + reward = 0.0 rewards.append(reward) weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] @@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv): start_time = time.time() - # Create evaluation tasks - tasks = [self.rollout_and_score(item) for item in self.eval_items] + # Run sequentially to avoid signal/threading issues with math_verify parser + # The parser uses signals for timeouts which only work in main thread + from tqdm import tqdm - # Run with progress bar - results = await tqdm_asyncio.gather(*tasks, desc="Evaluating") + results = [] + for item in tqdm(self.eval_items, desc="Evaluating"): + result = await self.rollout_and_score(item) + results.append(result) # Filter out failed results valid_results = [r for r in results if r is not None] diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 434e1368..42d77c1c 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -39,10 +39,12 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install """ +import asyncio import time from typing import Any, Dict, List, Optional, Tuple import verifiers as vf +from openai import AsyncOpenAI from pydantic import Field from tqdm.asyncio import tqdm_asyncio @@ -82,11 +84,21 @@ class VerifiersEnv(BaseEnv): self.rubric = self.vf_env.rubric self.parser = self.rubric.parser - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights - self.reward_scales = [ - weight / sum(self.reward_weights) for weight in self.reward_weights - ] + + # Handle both single Rubric and RubricGroup (composite) + # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics + if hasattr(self.rubric, "rubrics"): + self.reward_funcs = [] + self.reward_weights = [] + for rubric in self.rubric.rubrics: + self.reward_funcs.extend(rubric.funcs) + self.reward_weights.extend(rubric.weights) + else: + self.reward_funcs = self.rubric.funcs + self.reward_weights = self.rubric.weights + + total = sum(self.reward_weights) if self.reward_weights else 1.0 + self.reward_scales = [weight / total for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt @classmethod @@ -135,9 +147,15 @@ class VerifiersEnv(BaseEnv): async def setup(self): train_data = self.vf_env.get_dataset() - self.train = train_data.select_columns(["question", "answer"]).to_list() + # Only load columns we need to avoid memory bloat + columns_to_keep = ["question", "answer", "info"] + available_columns = [c for c in columns_to_keep if c in train_data.column_names] + self.train = train_data.select_columns(available_columns).to_list() test_data = self.vf_env.get_eval_dataset() - self.test = test_data.select_columns(["question", "answer"]).to_list() + available_test_columns = [ + c for c in columns_to_keep if c in test_data.column_names + ] + self.test = 
test_data.select_columns(available_test_columns).to_list() self.iter = 0 def save_checkpoint(self, step, data=None): @@ -254,70 +272,116 @@ class VerifiersEnv(BaseEnv): async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: """ - Collect trajectories - automatically switches between: - - ManagedServer (for RL training with local server - requires logprobs) - - tokenize_for_trainer (for SFT datagen with any API - no logprobs needed) + Collect trajectories - switches between: + - SFT data generation (process mode): Any API, no logprobs needed + - RL training (serve mode): Local server with logprobs """ - question = item["question"] - answer = item["answer"] - - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": question}, - ] - - # Check if we're in process mode (SFT data generation) is_process_mode = getattr(self, "process_mode", False) if is_process_mode: - return await self._collect_trajectories_for_sft(messages, answer) + return await self._collect_trajectories_for_sft(item) else: - return await self._collect_trajectories_for_training(messages, answer) + return await self._collect_trajectories_for_rl(item) async def _collect_trajectories_for_sft( - self, messages: List[Dict], answer: str + self, item: Dict[str, Any] ) -> Tuple[ScoredDataGroup, list]: """ SFT data generation mode - works with ANY API (OpenAI, Claude, local). Does NOT require logprobs or local server. - Uses tokenize_for_trainer to tokenize completions with your training - tokenizer, so the resulting data is ready for fine-tuning your target model. + Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer + to tokenize completions with your training tokenizer. """ - completions = await self.server.chat_completion( - messages=messages, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, + question = item["question"] + answer = item["answer"] + + # Build initial messages + initial_messages: List[Dict[str, str]] = [] + if self.system_prompt: + initial_messages.append({"role": "system", "content": self.system_prompt}) + initial_messages.append({"role": "user", "content": question}) + + # Create AsyncOpenAI client directly from server config (no ManagedServer needed) + server_config = self.server.servers[0].config + client = AsyncOpenAI( + api_key=server_config.api_key, + base_url=server_config.base_url, + timeout=server_config.timeout, ) + # Sampling args - use max_completion_tokens for newer models like gpt-5 + sampling_args = { + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + } + scored_data = ScoredDataGroup() scored_data["tokens"] = [] scored_data["masks"] = [] scored_data["scores"] = [] scored_data["messages"] = [] - # Note: No inference_logprobs - not needed/available for SFT - for choice in completions.choices: - response = choice.message.content or "" - finish_reason = choice.finish_reason or "" + # Semaphore for scoring (required by rubric.score_rollout) + score_sem = asyncio.Semaphore(1) - # Build full conversation for scoring and tokenization - completion_messages = messages + [ - {"role": "assistant", "content": response} + # Run rollouts in parallel for group_size + async def run_single_rollout(example_id: int): + # Pass through any info from the dataset item (e.g., docker_image for SWE envs) + item_info = item.get("info", {}) + rollout_input = { + "prompt": initial_messages, + "answer": answer, + "example_id": example_id, + "task": self.config.vf_env_name, + 
"info": item_info, + } + state = await self.vf_env.rollout( + input=rollout_input, + client=client, + model=server_config.model_name, + sampling_args=sampling_args, + ) + # Score the rollout using verifiers rubric (computes reward from test output) + # This is needed because vf_env.rollout() doesn't call score_rollout + await self.rubric.score_rollout(state, score_sem=score_sem) + return state + + # Run group_size rollouts in parallel + rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)] + states = await asyncio.gather(*rollout_tasks) + + for state in states: + # Extract completion messages from state + completion_messages = list(state.get("prompt", [])) + list( + state.get("completion", []) + ) + # Ensure all message contents are strings (not None) + # This can happen with tool call messages that have content: null + completion_messages = [ + {**msg, "content": msg.get("content") or ""} + for msg in completion_messages ] - # Score using verifiers reward funcs - score = self._compute_score(completion_messages, answer) + # Get reward from verifiers scoring (set by rubric.score_rollout above) + score = state.get("reward", 0.0) + + # Determine finish reason from last trajectory step + trajectory = state.get("trajectory", []) + if trajectory: + finish_reason = trajectory[-1]["response"].choices[0].finish_reason + else: + finish_reason = "stop" # Use tokenize_for_trainer for tokenization - # This uses YOUR training tokenizer (e.g., Qwen, Llama), not the API's tokenizer - # So GPT-4o responses get tokenized for your target model + # train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked + # for multi-turn environments, not just the last message tokenized = tokenize_for_trainer( tokenizer=self.tokenizer, chat=completion_messages, include_messages=True, finish_reason=finish_reason, + train_on_all_assistant_turns=True, ) scored_data["tokens"].append(tokenized["tokens"]) @@ -331,49 +395,73 @@ class VerifiersEnv(BaseEnv): return scored_data, [] - async def _collect_trajectories_for_training( - self, messages: List[Dict], answer: str + async def _collect_trajectories_for_rl( + self, item: Dict[str, Any] ) -> Tuple[ScoredDataGroup, list]: """ - RL training mode - requires local inference server. - Uses ManagedServer for proper token/logprob alignment. - - The inference_logprobs are required for policy gradient methods like - GRPO, PPO, REINFORCE, etc. + RL training mode - requires local inference server for logprobs. + Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn. 
""" - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - completions = await managed.chat_completion( - messages=messages, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, - ) - state = managed.get_state() - nodes = state["nodes"] + from atroposlib.envs.server_handling.atropos_managed_client import ( + AtroposManagedClient, + ) + + question = item["question"] + answer = item["answer"] + item_info = item.get("info", {}) + + initial_messages: List[Dict[str, str]] = [] + if self.system_prompt: + initial_messages.append({"role": "system", "content": self.system_prompt}) + initial_messages.append({"role": "user", "content": question}) + + sampling_args = { + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + } scored_data = ScoredDataGroup() scored_data["tokens"] = [] scored_data["masks"] = [] scored_data["scores"] = [] - scored_data["inference_logprobs"] = [] # Required for RL training! + scored_data["inference_logprobs"] = [] - for i, choice in enumerate(completions.choices): - response = choice.message.content or "" + # Semaphore for scoring (required by rubric.score_rollout) + score_sem = asyncio.Semaphore(1) - # Build full conversation for scoring - completion_messages = messages + [ - {"role": "assistant", "content": response} - ] + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + client = AtroposManagedClient( + managed_server=managed, + model=self.server_configs[0].model_name, + ) - # Score using verifiers reward funcs - score = self._compute_score(completion_messages, answer) + # Run group_size rollouts sequentially (ManagedServer state must be reset between) + for i in range(self.config.group_size): + client.reset() - # Use ManagedServer's properly aligned tokens/masks/logprobs - node = nodes[i] - scored_data["tokens"].append(node.tokens) - scored_data["masks"].append(node.masked_tokens) - scored_data["inference_logprobs"].append(node.logprobs) - scored_data["scores"].append(score) + rollout_input = { + "prompt": initial_messages, + "answer": answer, + "example_id": i, + "task": self.config.vf_env_name, + "info": item_info, + } + + state = await self.vf_env.rollout( + input=rollout_input, + client=client, + model=self.server_configs[0].model_name, + sampling_args=sampling_args, + ) + + # Score the rollout (computes reward from test output) + await self.rubric.score_rollout(state, score_sem=score_sem) + + tokens, masks, logprobs, score = self._extract_from_state(state) + scored_data["tokens"].append(tokens) + scored_data["masks"].append(masks) + scored_data["inference_logprobs"].append(logprobs) + scored_data["scores"].append(score) # Track scores for wandb logging for score in scored_data["scores"]: @@ -381,6 +469,42 @@ class VerifiersEnv(BaseEnv): return scored_data, [] + def _extract_from_state( + self, state: Any + ) -> Tuple[List[int], List[int], List[float], float]: + """ + Extract tokens/masks/logprobs/score from a single rollout state. 
+ + Handles the mask convention conversion: + - Verifiers: prompt_mask=0, completion_mask=1 + - Atropos: masked_tokens=-100 (prompt), token_id (completion) + """ + all_tokens: List[int] = [] + all_masks: List[int] = [] + all_logprobs: List[float] = [] + + trajectory = state.get("trajectory", []) + + for step in trajectory: + tokens = step["tokens"] + + prompt_ids = tokens["prompt_ids"] + completion_ids = tokens["completion_ids"] + completion_logprobs = tokens["completion_logprobs"] + + all_tokens.extend(prompt_ids) + all_tokens.extend(completion_ids) + + all_masks.extend([-100] * len(prompt_ids)) + all_masks.extend(completion_ids) + + all_logprobs.extend([1.0] * len(prompt_ids)) + all_logprobs.extend(completion_logprobs) + + reward = state["reward"] + + return all_tokens, all_masks, all_logprobs, reward + if __name__ == "__main__": VerifiersEnv.cli() From 294b9806256c814ca49b2ae25083ced77fb22471 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Sat, 10 Jan 2026 14:57:03 +0530 Subject: [PATCH 09/22] add tests for AtroposManagedClient --- .../tests/test_atropos_managed_client.py | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 atroposlib/tests/test_atropos_managed_client.py diff --git a/atroposlib/tests/test_atropos_managed_client.py b/atroposlib/tests/test_atropos_managed_client.py new file mode 100644 index 00000000..cb1cf955 --- /dev/null +++ b/atroposlib/tests/test_atropos_managed_client.py @@ -0,0 +1,235 @@ +"""Tests for AtroposManagedClient - AsyncOpenAI-compatible wrapper for ManagedServer.""" + +import pytest + +from atroposlib.envs.server_handling.atropos_managed_client import ( + AtroposManagedClient, + ChoiceLogprobs, + EnhancedChatCompletion, + LogprobContent, +) +from atroposlib.envs.server_handling.managed_server import ManagedServer +from atroposlib.envs.server_handling.server_harness import ServerHarness + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def __init__(self): + self.eos_token_id = 2 + self.bos_token_id = 1 + + def encode(self, text, add_special_tokens=True): + """Simple character-based encoding for testing.""" + tokens = [ord(c) for c in text] + if add_special_tokens: + tokens = [self.bos_token_id] + tokens + return tokens + + def decode(self, tokens, skip_special_tokens=False): + """Simple character-based decoding for testing.""" + if skip_special_tokens: + tokens = [ + t for t in tokens if t not in [self.bos_token_id, self.eos_token_id] + ] + return "".join([chr(t) if t > 31 else "" for t in tokens]) + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + """Simple chat template for testing.""" + result = "" + for msg in messages: + result += f"<{msg['role']}>{msg['content']}" + if add_generation_prompt: + result += "" + if tokenize: + return self.encode(result) + return result + + +@pytest.fixture +def mock_server(): + """Create a mock server with a tokenizer.""" + server = ServerHarness() + server.tokenizer = MockTokenizer() + + class Config: + model_name = "test_model" + + server.config = Config() + return server + + +@pytest.fixture +def managed_client(mock_server): + """Create an AtroposManagedClient with mocked server.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + return AtroposManagedClient(managed_server=managed, model="test_model") + + +class TestDataclasses: + """Test the enhanced dataclasses.""" + + def test_logprob_content(self): + """Test LogprobContent creation.""" + lp = LogprobContent(logprob=-0.5, token="hello", 
token_id=100) + assert lp.logprob == -0.5 + assert lp.token == "hello" + assert lp.token_id == 100 + + def test_choice_logprobs(self): + """Test ChoiceLogprobs structure.""" + content = [ + LogprobContent(logprob=-0.1), + LogprobContent(logprob=-0.2), + ] + logprobs = ChoiceLogprobs(content=content) + assert len(logprobs.content) == 2 + assert logprobs.content[0].logprob == -0.1 + + +class TestAtroposManagedClient: + """Test AtroposManagedClient behavior.""" + + def test_reset(self, managed_client): + """Test reset clears ManagedServer state.""" + # Add some state to managed server + managed_client.managed_server.current_nodes = ["dummy"] + + # Reset should clear it + managed_client.reset() + assert len(managed_client.managed_server.current_nodes) == 0 + + def test_copy_returns_self(self, managed_client): + """Test copy returns same instance for state sharing.""" + copied = managed_client.copy() + assert copied is managed_client + + def test_namespace_structure(self, managed_client): + """Test client has correct namespace structure like AsyncOpenAI.""" + assert hasattr(managed_client, "chat") + assert hasattr(managed_client.chat, "completions") + assert hasattr(managed_client.chat.completions, "create") + + @pytest.mark.asyncio + async def test_close_is_noop(self, managed_client): + """Test close() doesn't raise.""" + await managed_client.close() # Should not raise + + +class TestChatCompletionCreate: + """Test the chat.completions.create() method.""" + + @pytest.mark.asyncio + async def test_basic_completion(self, mock_server, managed_client): + """Test basic chat completion returns enhanced response.""" + messages = [{"role": "user", "content": "Hello"}] + managed = managed_client.managed_server + prompt = managed._convert_messages_to_prompt(messages) + prompt_tokens = mock_server.tokenizer.encode(prompt) + + output_text = "Hi there!" 
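+        # MockTokenizer encodes one token per character, so the expected
+        # completion ids below are just the ord() values of the output text.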
+ output_tokens = [ord(c) for c in output_text] + output_logprobs = [-0.1] * len(output_tokens) + + mock_server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=[output_tokens], + output_logprobs_list=[output_logprobs], + finish_reasons=["stop"], + ) + + result = await managed_client.chat.completions.create( + messages=messages, + max_tokens=100, + ) + + # Should return EnhancedChatCompletion + assert isinstance(result, EnhancedChatCompletion) + assert len(result.choices) == 1 + assert result.choices[0].message.content == output_text + + # Should have prompt_token_ids + assert len(result.prompt_token_ids) == len(prompt_tokens) + + # Should have token_ids on choice + assert len(result.choices[0].token_ids) == len(output_tokens) + assert result.choices[0].token_ids == output_tokens + + # Should have logprobs + assert len(result.choices[0].logprobs.content) == len(output_tokens) + assert result.choices[0].logprobs.content[0].logprob == -0.1 + + @pytest.mark.asyncio + async def test_max_completion_tokens_param(self, mock_server, managed_client): + """Test max_completion_tokens is preferred over max_tokens.""" + messages = [{"role": "user", "content": "Hi"}] + managed = managed_client.managed_server + prompt = managed._convert_messages_to_prompt(messages) + prompt_tokens = mock_server.tokenizer.encode(prompt) + + output_tokens = [ord("!")] + output_logprobs = [-0.1] + + mock_server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=[output_tokens], + output_logprobs_list=[output_logprobs], + finish_reasons=["stop"], + ) + + # Should accept max_completion_tokens (new OpenAI param) + result = await managed_client.chat.completions.create( + messages=messages, + max_completion_tokens=50, + ) + + assert isinstance(result, EnhancedChatCompletion) + + @pytest.mark.asyncio + async def test_reset_between_rollouts(self, mock_server, managed_client): + """Test that reset clears state between rollouts.""" + messages = [{"role": "user", "content": "Hello"}] + managed = managed_client.managed_server + prompt = managed._convert_messages_to_prompt(messages) + prompt_tokens = mock_server.tokenizer.encode(prompt) + + output_tokens = [ord("!")] + output_logprobs = [-0.1] + + mock_server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=[output_tokens], + output_logprobs_list=[output_logprobs], + finish_reasons=["stop"], + ) + + # First rollout + await managed_client.chat.completions.create(messages=messages, max_tokens=10) + state = managed_client.managed_server.get_state() + assert len(state["nodes"]) == 1 + + # Reset + managed_client.reset() + state = managed_client.managed_server.get_state() + assert len(state["nodes"]) == 0 + + # Setup for second rollout + mock_server.set_tokens_and_logprobs_response( + prompt=prompt, + prompt_tokens=prompt_tokens, + output_tokens_list=[output_tokens], + output_logprobs_list=[output_logprobs], + finish_reasons=["stop"], + ) + + # Second rollout + await managed_client.chat.completions.create(messages=messages, max_tokens=10) + state = managed_client.managed_server.get_state() + assert len(state["nodes"]) == 1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From d98bc6d9fc89d84d4a5ec45af7d336b7b3cee390 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 10 Jan 2026 09:27:53 +0000 Subject: [PATCH 10/22] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- environments/eval_environments/verifiers_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index c5fb1063..3635d94f 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -24,9 +24,9 @@ import time from typing import Any, Callable, Dict, List, Optional, Tuple import verifiers as vf +import wandb from pydantic import Field -import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, From 24b4488c60f23c5079b88046e777291af6132580 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Mon, 12 Jan 2026 05:38:15 +0530 Subject: [PATCH 11/22] clean up eval, pin verifiers version --- .../eval_environments/verifiers_eval.py | 362 +++++++----------- environments/verifiers_server.py | 82 +++- pyproject.toml | 2 +- 3 files changed, 202 insertions(+), 244 deletions(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 3635d94f..df28973b 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -4,6 +4,9 @@ Verifiers Evaluation Environment for Atropos This environment evaluates models using Prime Intellect's Verifiers library. It supports any environment registered with the Verifiers ecosystem. +Uses verifiers' native rollout and scoring machinery - just pass an OpenAI-compatible +client and verifiers handles generation, parsing, and scoring. + To install a Verifiers/Prime environment: 1. uv tool install prime 2. prime login @@ -15,18 +18,24 @@ Usage: --env.vf_env_name primeintellect/gsm8k \ --openai.model_name gpt-4.1-nano \ --openai.api_key $OPENAI_API_KEY + +Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.): + python verifiers_eval.py evaluate \ + --env.vf_env_name primeintellect/gsm8k \ + --openai.model_name Qwen/Qwen2.5-7B-Instruct \ + --openai.base_url http://localhost:8000/v1 """ -import asyncio -import inspect import os import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import verifiers as vf -import wandb +from openai import AsyncOpenAI from pydantic import Field +import wandb + from atroposlib.envs.base import ( APIServerConfig, BaseEnv, @@ -36,14 +45,12 @@ from atroposlib.envs.base import ( # Patch math_verify timeout to work in async context # The signal-based timeout doesn't work in non-main threads (asyncio event loop) -def _no_signal_timeout(timeout_seconds: int): +def _no_signal_timeout(_timeout_seconds: int): """Replacement timeout decorator that doesn't use signals.""" def decorator(func): def wrapper(*args, **kwargs): - # Just call the function without timeout - # This is safe because we're in an async context with our own timeouts - # timeout_seconds is intentionally unused - we're replacing the timeout logic + # Just call the function without timeout - safe in async context return func(*args, **kwargs) return wrapper @@ -67,41 +74,47 @@ except ImportError: class VerifiersEvaluationConfig(BaseEnvConfig): """Configuration for Verifiers evaluation environment.""" - # Verifiers environment vf_env_name: str = Field( default="", description="Verifiers environment name (e.g., primeintellect/gsm8k)", ) - env_args: dict = Field( + env_args: Dict[str, Any] = Field( default_factory=dict, 
description="Additional arguments for verifiers environment", ) - temperature: float = Field( - default=0.0, description="Temperature for generation (0.0 for deterministic)" + default=0.0, + description="Temperature for generation (0.0 for deterministic)", + ) + max_eval_items: int = Field( + default=-1, + description="Maximum number of items to evaluate (-1 for all)", + ) + max_concurrent: int = Field( + default=64, + description="Maximum concurrent requests to the model", ) - max_retries: int = Field( - default=3, description="Maximum retries for failed API calls" - ) - retry_delay: float = Field( - default=1.0, description="Delay between retries in seconds" - ) - min_response_length: int = Field( - default=1, description="Minimum response length to consider valid" - ) - full_debug: bool = Field(default=False, description="Enable full debug output") - max_eval_items: int = Field( - default=-1, description="Maximum number of items to evaluate (-1 for all)" - ) + # Override BaseEnvConfig defaults for evaluation + group_size: int = 1 + max_num_workers: int = 1024 + max_eval_workers: int = 256 + max_num_workers_per_node: int = 128 + use_wandb: bool = True + rollout_server_url: str = "http://localhost:8000" + total_steps: int = 1 + steps_per_eval: int = 1 + wandb_name: str = "verifiers_eval" class VerifiersEvaluationEnv(BaseEnv): """ Verifiers Evaluation Environment. - Evaluates models using Prime Intellect's Verifiers library rubrics. - Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, etc.) + Evaluates models using Prime Intellect's Verifiers library. + Uses verifiers' native rollout and scoring machinery. + + Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ name = "verifiers_evaluation" @@ -117,39 +130,10 @@ class VerifiersEvaluationEnv(BaseEnv): super().__init__(config, server_configs, slurm, testing) self.config: VerifiersEvaluationConfig = config - # Load verifiers environment self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) - self.rubric = self.vf_env.rubric - # Extract rubric components from RubricGroup - # RubricGroup.funcs is empty - need to collect from individual rubrics - self.parser = self.rubric.parser - self.reward_funcs: List[Callable] = [] - self.reward_weights: List[float] = [] - self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func - - if hasattr(self.rubric, "rubrics"): - # RubricGroup: collect from all individual rubrics - for rubric in self.rubric.rubrics: - class_objects = getattr(rubric, "class_objects", {}) - for func, weight in zip(rubric.funcs, rubric.weights): - self.reward_funcs.append(func) - self.reward_weights.append(weight) - self.rubric_class_objects.append(class_objects) - else: - # Single Rubric - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights - class_objects = getattr(self.rubric, "class_objects", {}) - self.rubric_class_objects = [class_objects] * len(self.rubric.funcs) - - total_weight = sum(self.reward_weights) if self.reward_weights else 1.0 - self.reward_scales = [weight / total_weight for weight in self.reward_weights] - self.system_prompt = self.vf_env.system_prompt - - # Tracking - self.eval_items: List[Dict] = [] - self._dataset_loaded = False + # Get reward function names for metrics reporting + self.reward_func_names = self.vf_env.rubric._get_reward_func_names() @classmethod def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]: @@ -166,183 +150,87 @@ class VerifiersEvaluationEnv(BaseEnv): ] return 
env_config, server_configs + def _get_openai_client(self) -> AsyncOpenAI: + """Create AsyncOpenAI client from first server config.""" + server = self.server.servers[0] + config = server.config + return AsyncOpenAI( + api_key=config.api_key or "x", + base_url=config.base_url, + timeout=config.timeout, + ) + + def _get_model_name(self) -> str: + """Get model name from first server config.""" + return self.server.servers[0].config.model_name + async def setup(self) -> None: - """Initialize the environment and load datasets.""" - if not self._dataset_loaded: - # Load datasets from verifiers environment - test_data = self.vf_env.get_eval_dataset() - self.eval_items = test_data.select_columns(["question", "answer"]).to_list() - - # Limit items if max_eval_items is set - if self.config.max_eval_items > 0: - self.eval_items = self.eval_items[: self.config.max_eval_items] - - self._dataset_loaded = True + """Initialize the environment.""" + num_eval = len(self.vf_env.get_eval_dataset()) + if self.config.max_eval_items > 0: + num_eval = min(num_eval, self.config.max_eval_items) print("\nVerifiers Evaluation Setup:") print(f" Environment: {self.config.vf_env_name}") - print(f" Reward functions: {len(self.reward_funcs)}") - print(f" Reward weights: {self.reward_weights}") - print(f" Loaded {len(self.eval_items)} evaluation items") + print(f" Reward functions: {self.reward_func_names}") + print(f" Evaluation items: {num_eval}") + print(f" Max concurrent: {self.config.max_concurrent}") - async def rollout_and_score(self, item: Dict) -> Optional[Dict]: - """ - Run evaluation on a single item and return the result. + async def evaluate(self) -> Dict: + """Run evaluation using verifiers' native machinery.""" + num_examples = ( + self.config.max_eval_items if self.config.max_eval_items > 0 else -1 + ) - Args: - item: Dict with 'question' and 'answer' keys - - Returns: - Dict with evaluation results or None if failed - """ - question = item["question"] - answer = item["answer"] - - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": question}, - ] - - # Build API call parameters - kwargs = { - "messages": messages, - "temperature": self.config.temperature, - "max_tokens": self.config.max_token_length, - "n": 1, - } - - response_text = "" - for attempt in range(self.config.max_retries): - try: - # Direct API call (no ManagedServer) - eval doesn't need token tracking - response = await self.server.chat_completion(**kwargs) - response_text = response.choices[0].message.content or "" - - if len(response_text) >= self.config.min_response_length: - break - - except Exception as e: - if self.config.full_debug: - print(f" API error (attempt {attempt + 1}): {e}") - if attempt < self.config.max_retries - 1: - await asyncio.sleep(self.config.retry_delay) - continue - - if not response_text: - return None - - # Build completion messages for scoring - completion_messages = messages + [ - {"role": "assistant", "content": response_text} - ] - - # Parse answer - answer_parsed = self.parser.parse_answer(completion=response_text) - - # Score using reward funcs (async functions need await) - # Use signature introspection to pass only required params (like verifiers does) - rewards = [] - for i, func in enumerate(self.reward_funcs): - try: - # Build merged dict of all possible parameters - class_objects = self.rubric_class_objects[i] - merged = { - "completion": completion_messages, - "answer": answer, - "prompt": question, - } - merged.update(class_objects) # Adds parser, etc. 
- - # Filter to only params the function accepts - sig = inspect.signature(func) - if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): - # Function accepts **kwargs, pass everything - kwargs = merged - else: - # Only pass params in signature - kwargs = {k: v for k, v in merged.items() if k in sig.parameters} - - result = func(**kwargs) - # Reward functions may be async coroutines - if asyncio.iscoroutine(result): - reward = await result - else: - reward = result - reward = float(reward) - except Exception as e: - if self.config.full_debug: - print(f" Reward func {func.__name__} error: {e}") - reward = 0.0 - rewards.append(reward) - - weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] - score = sum(weighted_rewards) - - if self.config.full_debug: - print("\n--- Item ---") - print(f"Question: {question[:100]}...") - print(f"Gold answer: {answer}") - print(f"Model parsed: {answer_parsed}") - print(f"Rewards: {rewards}") - print(f"Score: {score}") - - return { - "question": question, - "gold_answer": answer, - "response": response_text, - "model_parsed": str(answer_parsed) if answer_parsed else None, - "rewards": rewards, - "weighted_rewards": weighted_rewards, - "score": score, - "correct": bool(score > 0), - } - - async def evaluate(self, *args, **kwargs) -> Dict: - """Run the full evaluation.""" print(f"\n{'=' * 60}") print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}") print(f"{'=' * 60}") - print(f" Total questions: {len(self.eval_items)}") + print(f" Model: {self._get_model_name()}") print(f" Temperature: {self.config.temperature}") + print(f" Max concurrent: {self.config.max_concurrent}") print(f"{'=' * 60}\n") start_time = time.time() - # Run sequentially to avoid signal/threading issues with math_verify parser - # The parser uses signals for timeouts which only work in main thread - from tqdm import tqdm + # Create OpenAI client from atropos server config + client = self._get_openai_client() + model = self._get_model_name() - results = [] - for item in tqdm(self.eval_items, desc="Evaluating"): - result = await self.rollout_and_score(item) - results.append(result) - - # Filter out failed results - valid_results = [r for r in results if r is not None] - - if not valid_results: - print("Warning: No valid evaluation results obtained") - return {"error": "No valid results", "accuracy": 0.0} + # Let verifiers handle everything: rollouts + scoring + results = await self.vf_env.evaluate( + client=client, + model=model, + sampling_args={ + "temperature": self.config.temperature, + "max_tokens": self.config.max_token_length, + }, + num_examples=num_examples, + max_concurrent=self.config.max_concurrent, + save_results=False, + ) end_time = time.time() - # Calculate metrics - total = len(valid_results) - scores = [r["score"] for r in valid_results] - correct = sum(1 for r in valid_results if r["correct"]) + # Extract metrics from verifiers output + rewards = results["reward"] + per_func_metrics = results["metrics"] # dict of func_name -> list[float] + prompts = results["prompt"] + completions = results["completion"] + answers = results["answer"] - avg_score = sum(scores) / total if total > 0 else 0.0 + total = len(rewards) + correct = sum(1 for r in rewards if r > 0) + avg_score = sum(rewards) / total if total > 0 else 0.0 accuracy = correct / total if total > 0 else 0.0 # Per-reward function breakdown reward_breakdown = {} - for i, weight in enumerate(self.reward_weights): - func_rewards = [r["rewards"][i] for r in valid_results] - 
reward_breakdown[f"reward_func_{i}"] = { - "weight": weight, - "avg": sum(func_rewards) / len(func_rewards), - "correct": sum(1 for r in func_rewards if r > 0), - } + for func_name, values in per_func_metrics.items(): + if values: + reward_breakdown[func_name] = { + "avg": sum(values) / len(values), + "correct": sum(1 for v in values if v > 0), + } metrics = { "avg_score": avg_score, @@ -366,22 +254,32 @@ class VerifiersEvaluationEnv(BaseEnv): ) print(f"{'=' * 60}\n") - # Log to evaluate_log - samples = [ - { - "messages": [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": r["question"]}, - {"role": "assistant", "content": r["response"]}, - ], - "question": r["question"], - "gold_answer": r["gold_answer"], - "model_parsed": r["model_parsed"], - "score": r["score"], - "correct": r["correct"], - } - for r in valid_results - ] + # Log to evaluate_log (atropos's logging system) + system_prompt = self.vf_env.system_prompt or "" + samples = [] + for i in range(min(total, 100)): # Limit samples for logging + prompt_msgs = prompts[i] if isinstance(prompts[i], list) else [] + completion_msgs = completions[i] if completions[i] else [] + + # Build full message list + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(prompt_msgs) + if isinstance(completion_msgs, list): + messages.extend(completion_msgs) + + samples.append( + { + "messages": messages, + "gold_answer": answers[i] if i < len(answers) else "", + "score": rewards[i], + "correct": rewards[i] > 0, + "metrics": { + k: v[i] for k, v in per_func_metrics.items() if i < len(v) + }, + } + ) await self.evaluate_log( metrics={"accuracy": accuracy, "avg_score": avg_score}, @@ -430,13 +328,11 @@ class VerifiersEvaluationEnv(BaseEnv): # Required abstract method implementations (stubs for evaluation-only mode) async def get_next_item(self) -> Optional[Dict]: """Not used in evaluation mode.""" - raise NotImplementedError("get_next_item not supported in evaluation-only mode") + return None - async def collect_trajectories(self, item) -> Tuple[List, List]: + async def collect_trajectories(self, item) -> Tuple[List, List]: # noqa: ARG002 """Not used in evaluation mode.""" - raise NotImplementedError( - "collect_trajectories not supported in evaluation-only mode" - ) + return [], [] if __name__ == "__main__": diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 42d77c1c..accbf89c 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -40,8 +40,9 @@ Docs: https://docs.primeintellect.ai/tutorials-environments/install """ import asyncio +import logging import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import verifiers as vf from openai import AsyncOpenAI @@ -56,6 +57,56 @@ from atroposlib.envs.base import ( ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Verifiers API Compatibility Layer +# ============================================================================= + + +def _get_rubric_reward_funcs(rubric: vf.Rubric) -> List[Callable]: + """ + Get reward functions from a Rubric with API version compatibility. 
+ + Handles different verifiers API versions: + - v0.1.9+: rubric._get_reward_funcs() (private) + - v0.1.5-0.1.8: rubric.get_reward_funcs() (public) + - fallback: rubric.funcs (direct access) + """ + if hasattr(rubric, "_get_reward_funcs"): + return rubric._get_reward_funcs() + elif hasattr(rubric, "get_reward_funcs"): + return rubric.get_reward_funcs() + elif hasattr(rubric, "funcs"): + return rubric.funcs + else: + raise AttributeError( + f"Cannot find reward functions on rubric. " + f"Available attrs: {dir(rubric)}" + ) + + +def _get_rubric_reward_weights(rubric: vf.Rubric) -> List[float]: + """ + Get reward weights from a Rubric with API version compatibility. + + Handles different verifiers API versions: + - v0.1.9+: rubric._get_reward_weights() (private) + - v0.1.5-0.1.8: rubric.get_reward_weights() (public) + - fallback: rubric.weights (direct access) + """ + if hasattr(rubric, "_get_reward_weights"): + return rubric._get_reward_weights() + elif hasattr(rubric, "get_reward_weights"): + return rubric.get_reward_weights() + elif hasattr(rubric, "weights"): + return rubric.weights + else: + raise AttributeError( + f"Cannot find reward weights on rubric. " f"Available attrs: {dir(rubric)}" + ) + class VfEnvConfig(BaseEnvConfig): """ @@ -78,8 +129,11 @@ class VerifiersEnv(BaseEnv): testing=False, ): super().__init__(config, server_configs, slurm, testing) - self.percent_correct_buffer = list() - self.eval_metrics = list() + self.percent_correct_buffer: List[float] = [] + self.eval_metrics: List[Tuple[str, float]] = [] + + # Load verifiers environment + logger.info("Loading verifiers environment: %s", config.vf_env_name) self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric @@ -88,19 +142,27 @@ class VerifiersEnv(BaseEnv): # Handle both single Rubric and RubricGroup (composite) # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics if hasattr(self.rubric, "rubrics"): - self.reward_funcs = [] - self.reward_weights = [] - for rubric in self.rubric.rubrics: - self.reward_funcs.extend(rubric.funcs) - self.reward_weights.extend(rubric.weights) + # RubricGroup: collect from all individual rubrics + self.reward_funcs: List[Callable] = [] + self.reward_weights: List[float] = [] + for rubric in self.rubric.rubrics: # type: ignore[attr-defined] + self.reward_funcs.extend(_get_rubric_reward_funcs(rubric)) + self.reward_weights.extend(_get_rubric_reward_weights(rubric)) else: - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights + # Single Rubric: use compatibility layer + self.reward_funcs = _get_rubric_reward_funcs(self.rubric) + self.reward_weights = _get_rubric_reward_weights(self.rubric) total = sum(self.reward_weights) if self.reward_weights else 1.0 self.reward_scales = [weight / total for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt + logger.info( + "Loaded environment with %d reward functions, system_prompt=%s", + len(self.reward_funcs), + bool(self.system_prompt), + ) + @classmethod def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: env_config = VfEnvConfig( diff --git a/pyproject.toml b/pyproject.toml index 0b1ab5ac..fff88547 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ examples = [ "langdetect" ] verifiers = [ - "verifiers>=0.1.5.post0" + "verifiers==0.1.9.post2" ] [build-system] From dceb1d8fd82de0883f30af4ee484e79d542c0804 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Mon, 12 Jan 2026 
07:20:56 +0530 Subject: [PATCH 12/22] parallelize verifiers_server: use generate() for SFT, parallel ManagedServer contexts for RL --- environments/verifiers_server.py | 445 +++++++------------------------ 1 file changed, 93 insertions(+), 352 deletions(-) diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index accbf89c..7ce38da7 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -41,13 +41,11 @@ Docs: https://docs.primeintellect.ai/tutorials-environments/install import asyncio import logging -import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import verifiers as vf from openai import AsyncOpenAI from pydantic import Field -from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( APIServerConfig, @@ -60,59 +58,7 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer logger = logging.getLogger(__name__) -# ============================================================================= -# Verifiers API Compatibility Layer -# ============================================================================= - - -def _get_rubric_reward_funcs(rubric: vf.Rubric) -> List[Callable]: - """ - Get reward functions from a Rubric with API version compatibility. - - Handles different verifiers API versions: - - v0.1.9+: rubric._get_reward_funcs() (private) - - v0.1.5-0.1.8: rubric.get_reward_funcs() (public) - - fallback: rubric.funcs (direct access) - """ - if hasattr(rubric, "_get_reward_funcs"): - return rubric._get_reward_funcs() - elif hasattr(rubric, "get_reward_funcs"): - return rubric.get_reward_funcs() - elif hasattr(rubric, "funcs"): - return rubric.funcs - else: - raise AttributeError( - f"Cannot find reward functions on rubric. " - f"Available attrs: {dir(rubric)}" - ) - - -def _get_rubric_reward_weights(rubric: vf.Rubric) -> List[float]: - """ - Get reward weights from a Rubric with API version compatibility. - - Handles different verifiers API versions: - - v0.1.9+: rubric._get_reward_weights() (private) - - v0.1.5-0.1.8: rubric.get_reward_weights() (public) - - fallback: rubric.weights (direct access) - """ - if hasattr(rubric, "_get_reward_weights"): - return rubric._get_reward_weights() - elif hasattr(rubric, "get_reward_weights"): - return rubric.get_reward_weights() - elif hasattr(rubric, "weights"): - return rubric.weights - else: - raise AttributeError( - f"Cannot find reward weights on rubric. " f"Available attrs: {dir(rubric)}" - ) - - class VfEnvConfig(BaseEnvConfig): - """ - Configuration for the Verifiers environments. 
- """ - vf_env_name: str = "" env_args: Dict[str, Any] = Field(default_factory=dict) @@ -130,38 +76,12 @@ class VerifiersEnv(BaseEnv): ): super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer: List[float] = [] - self.eval_metrics: List[Tuple[str, float]] = [] - # Load verifiers environment logger.info("Loading verifiers environment: %s", config.vf_env_name) self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric - - self.parser = self.rubric.parser - - # Handle both single Rubric and RubricGroup (composite) - # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics - if hasattr(self.rubric, "rubrics"): - # RubricGroup: collect from all individual rubrics - self.reward_funcs: List[Callable] = [] - self.reward_weights: List[float] = [] - for rubric in self.rubric.rubrics: # type: ignore[attr-defined] - self.reward_funcs.extend(_get_rubric_reward_funcs(rubric)) - self.reward_weights.extend(_get_rubric_reward_weights(rubric)) - else: - # Single Rubric: use compatibility layer - self.reward_funcs = _get_rubric_reward_funcs(self.rubric) - self.reward_weights = _get_rubric_reward_weights(self.rubric) - - total = sum(self.reward_weights) if self.reward_weights else 1.0 - self.reward_scales = [weight / total for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt - - logger.info( - "Loaded environment with %d reward functions, system_prompt=%s", - len(self.reward_funcs), - bool(self.system_prompt), - ) + logger.info("Reward functions: %s", self.rubric._get_reward_func_names()) @classmethod def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: @@ -176,9 +96,6 @@ class VerifiersEnv(BaseEnv): max_token_length=2048, wandb_name="verifiers", ) - # Default config for local inference server (vLLM, SGLang, TRL) - # For SFT data generation with OpenAI, override via CLI: - # --openai.base_url https://api.openai.com/v1 --openai.model_name gpt-4o server_configs = [ APIServerConfig( model_name="gpt-4.1-nano", @@ -193,31 +110,19 @@ class VerifiersEnv(BaseEnv): if wandb_metrics is None: wandb_metrics = {} - # Calculate percent_correct from buffer if self.percent_correct_buffer: wandb_metrics["train/percent_correct"] = sum( self.percent_correct_buffer ) / len(self.percent_correct_buffer) - - self.percent_correct_buffer = list() - - for item in self.eval_metrics: - wandb_metrics[item[0]] = item[1] - self.eval_metrics = list() + self.percent_correct_buffer = [] await super().wandb_log(wandb_metrics) async def setup(self): train_data = self.vf_env.get_dataset() - # Only load columns we need to avoid memory bloat columns_to_keep = ["question", "answer", "info"] available_columns = [c for c in columns_to_keep if c in train_data.column_names] self.train = train_data.select_columns(available_columns).to_list() - test_data = self.vf_env.get_eval_dataset() - available_test_columns = [ - c for c in columns_to_keep if c in test_data.column_names - ] - self.test = test_data.select_columns(available_test_columns).to_list() self.iter = 0 def save_checkpoint(self, step, data=None): @@ -226,145 +131,32 @@ class VerifiersEnv(BaseEnv): data["iter"] = self.iter super().save_checkpoint(step, data) - def _compute_score(self, completion_messages: List[Dict], answer: str) -> float: - """Compute score using verifiers reward functions.""" - rewards = [] - for func in self.reward_funcs: - reward = func( - parser=self.parser, - completion=completion_messages, - answer=answer, - ) - 
rewards.append(reward) - weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] - return sum(weighted_rewards) - - async def rollout_and_score_eval( - self, question: str, answer: str, **kwargs - ) -> dict: - """ - Rollout and score for evaluation. - Uses ManagedServer in serve mode, direct API calls in process mode. - """ - system_prompt = kwargs.get("system_prompt") - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": question}, - ] - - is_process_mode = getattr(self, "process_mode", False) - - if is_process_mode: - # Process mode: use direct API call (works with any API) - completion = await self.server.chat_completion( - messages=messages, - n=1, - max_tokens=self.config.max_token_length, - temperature=0.0, - ) - else: - # Serve mode: use ManagedServer for token tracking - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - completion = await managed.chat_completion( - messages=messages, - n=1, - max_tokens=self.config.max_token_length, - temperature=0.0, - ) - - response_content = completion.choices[0].message.content or "" - messages.append({"role": "assistant", "content": response_content}) - - answer_parsed = self.parser.parse_answer(completion=response_content) - - score = self._compute_score(messages, answer) - - sample = { - "messages": messages, - "question": question, - "gold_answer": answer, - "model_parsed": str(answer_parsed) if answer_parsed else None, - "score": score, - "correct": bool(score), - "finish_reason": completion.choices[0].finish_reason, - } - - return {"score": score, "sample": sample} - - async def evaluate(self, *args, **kwargs): - start_time = time.time() - - eval_tasks = [] - for item in self.test: - eval_tasks.append( - self.rollout_and_score_eval( - item["question"], item["answer"], system_prompt=self.system_prompt - ) - ) - results = await tqdm_asyncio.gather(*eval_tasks) - - scores = [result["score"] for result in results] - samples = [result["sample"] for result in results] - - avg_total_score = sum(scores) / len(scores) - - end_time = time.time() - - self.eval_metrics.append(("eval/avg_total_score", avg_total_score)) - - eval_metrics = {"eval/avg_total_score": avg_total_score} - - await self.evaluate_log( - metrics=eval_metrics, - samples=samples, - start_time=start_time, - end_time=end_time, - generation_parameters={ - "temperature": 0.0, - "max_tokens": self.config.max_token_length, - }, - ) - - return eval_metrics - async def get_next_item(self): next_item = self.train[self.iter % len(self.train)] self.iter += 1 return next_item + async def evaluate(self) -> Dict[str, float]: + """No-op. 
Use environments/eval_environments/verifiers_eval.py for evaluation.""" + return {} + + def _build_initial_messages(self, question: str) -> List[Dict[str, str]]: + messages: List[Dict[str, str]] = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": question}) + return messages + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: - """ - Collect trajectories - switches between: - - SFT data generation (process mode): Any API, no logprobs needed - - RL training (serve mode): Local server with logprobs - """ is_process_mode = getattr(self, "process_mode", False) - if is_process_mode: - return await self._collect_trajectories_for_sft(item) - else: - return await self._collect_trajectories_for_rl(item) + return await self._collect_for_sft(item) + return await self._collect_for_rl(item) - async def _collect_trajectories_for_sft( + async def _collect_for_sft( self, item: Dict[str, Any] ) -> Tuple[ScoredDataGroup, list]: - """ - SFT data generation mode - works with ANY API (OpenAI, Claude, local). - Does NOT require logprobs or local server. - - Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer - to tokenize completions with your training tokenizer. - """ - question = item["question"] - answer = item["answer"] - - # Build initial messages - initial_messages: List[Dict[str, str]] = [] - if self.system_prompt: - initial_messages.append({"role": "system", "content": self.system_prompt}) - initial_messages.append({"role": "user", "content": question}) - - # Create AsyncOpenAI client directly from server config (no ManagedServer needed) + """SFT mode: uses vf_env.generate() for parallel generation + scoring.""" server_config = self.server.servers[0].config client = AsyncOpenAI( api_key=server_config.api_key, @@ -372,11 +164,31 @@ class VerifiersEnv(BaseEnv): timeout=server_config.timeout, ) - # Sampling args - use max_completion_tokens for newer models like gpt-5 - sampling_args = { - "temperature": 1.0, - "max_completion_tokens": self.config.max_token_length, - } + initial_messages = self._build_initial_messages(item["question"]) + inputs = [ + { + "prompt": initial_messages, + "answer": item["answer"], + "example_id": i, + "task": self.config.vf_env_name, + "info": item.get("info", {}), + } + for i in range(self.config.group_size) + ] + + results = await self.vf_env.generate( + inputs=inputs, + client=client, + model=server_config.model_name, + sampling_args={ + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + }, + max_concurrent=self.config.group_size, + max_concurrent_scoring=self.config.group_size, + save_results=False, + independent_scoring=True, + ) scored_data = ScoredDataGroup() scored_data["tokens"] = [] @@ -384,63 +196,22 @@ class VerifiersEnv(BaseEnv): scored_data["scores"] = [] scored_data["messages"] = [] - # Semaphore for scoring (required by rubric.score_rollout) - score_sem = asyncio.Semaphore(1) - - # Run rollouts in parallel for group_size - async def run_single_rollout(example_id: int): - # Pass through any info from the dataset item (e.g., docker_image for SWE envs) - item_info = item.get("info", {}) - rollout_input = { - "prompt": initial_messages, - "answer": answer, - "example_id": example_id, - "task": self.config.vf_env_name, - "info": item_info, - } - state = await self.vf_env.rollout( - input=rollout_input, - client=client, - model=server_config.model_name, - sampling_args=sampling_args, - ) - # Score 
the rollout using verifiers rubric (computes reward from test output) - # This is needed because vf_env.rollout() doesn't call score_rollout - await self.rubric.score_rollout(state, score_sem=score_sem) - return state - - # Run group_size rollouts in parallel - rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)] - states = await asyncio.gather(*rollout_tasks) - - for state in states: - # Extract completion messages from state - completion_messages = list(state.get("prompt", [])) + list( - state.get("completion", []) - ) - # Ensure all message contents are strings (not None) - # This can happen with tool call messages that have content: null - completion_messages = [ - {**msg, "content": msg.get("content") or ""} - for msg in completion_messages + for state in results["state"]: + messages = list(state.get("prompt", [])) + list(state.get("completion", [])) + messages = [ + {**msg, "content": msg.get("content") or ""} for msg in messages ] - # Get reward from verifiers scoring (set by rubric.score_rollout above) - score = state.get("reward", 0.0) - - # Determine finish reason from last trajectory step trajectory = state.get("trajectory", []) - if trajectory: - finish_reason = trajectory[-1]["response"].choices[0].finish_reason - else: - finish_reason = "stop" + finish_reason = ( + trajectory[-1]["response"].choices[0].finish_reason + if trajectory + else "stop" + ) - # Use tokenize_for_trainer for tokenization - # train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked - # for multi-turn environments, not just the last message tokenized = tokenize_for_trainer( tokenizer=self.tokenizer, - chat=completion_messages, + chat=messages, include_messages=True, finish_reason=finish_reason, train_on_all_assistant_turns=True, @@ -448,39 +219,54 @@ class VerifiersEnv(BaseEnv): scored_data["tokens"].append(tokenized["tokens"]) scored_data["masks"].append(tokenized["masks"]) - scored_data["messages"].append(completion_messages) - scored_data["scores"].append(score) + scored_data["messages"].append(messages) + scored_data["scores"].append(state.get("reward", 0.0)) - # Track scores for wandb logging for score in scored_data["scores"]: self.percent_correct_buffer.append(max(score, 0)) return scored_data, [] - async def _collect_trajectories_for_rl( + async def _collect_for_rl( self, item: Dict[str, Any] ) -> Tuple[ScoredDataGroup, list]: - """ - RL training mode - requires local inference server for logprobs. - Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn. 
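# A minimal sketch of the rollout-state fields this file relies on (only the
# keys read here; values and metric names are placeholders, not the library's
# full schema):
example_state = {
    "prompt": [{"role": "user", "content": "Sort: banana, apple"}],
    "completion": [{"role": "assistant", "content": "apple, banana"}],
    "trajectory": [],              # one dict per model turn; [-1]["response"] holds the raw API response
    "reward": 0.0,                 # filled in by rubric scoring
    "metrics": {"accuracy": 0.0},  # per-reward-function values; names vary by environment
}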
- """ + """RL mode: uses ManagedServer for logprobs tracking.""" from atroposlib.envs.server_handling.atropos_managed_client import ( AtroposManagedClient, ) - question = item["question"] - answer = item["answer"] - item_info = item.get("info", {}) - - initial_messages: List[Dict[str, str]] = [] - if self.system_prompt: - initial_messages.append({"role": "system", "content": self.system_prompt}) - initial_messages.append({"role": "user", "content": question}) - + initial_messages = self._build_initial_messages(item["question"]) sampling_args = { "temperature": 1.0, "max_completion_tokens": self.config.max_token_length, } + score_sem = asyncio.Semaphore(self.config.group_size) + model = self.server.servers[0].config.model_name + + async def run_rollout( + example_id: int, + ) -> Tuple[List[int], List[int], List[float], float]: + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + client = AtroposManagedClient(managed_server=managed, model=model) + + state = await self.vf_env.rollout( + input={ + "prompt": initial_messages, + "answer": item["answer"], + "example_id": example_id, + "task": self.config.vf_env_name, + "info": item.get("info", {}), + }, + client=client, + model=model, + sampling_args=sampling_args, + ) + await self.rubric.score_rollout(state, score_sem=score_sem) + return self._extract_from_state(state) + + results = await asyncio.gather( + *[run_rollout(i) for i in range(self.config.group_size)] + ) scored_data = ScoredDataGroup() scored_data["tokens"] = [] @@ -488,44 +274,12 @@ class VerifiersEnv(BaseEnv): scored_data["scores"] = [] scored_data["inference_logprobs"] = [] - # Semaphore for scoring (required by rubric.score_rollout) - score_sem = asyncio.Semaphore(1) + for tokens, masks, logprobs, score in results: + scored_data["tokens"].append(tokens) + scored_data["masks"].append(masks) + scored_data["inference_logprobs"].append(logprobs) + scored_data["scores"].append(score) - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - client = AtroposManagedClient( - managed_server=managed, - model=self.server_configs[0].model_name, - ) - - # Run group_size rollouts sequentially (ManagedServer state must be reset between) - for i in range(self.config.group_size): - client.reset() - - rollout_input = { - "prompt": initial_messages, - "answer": answer, - "example_id": i, - "task": self.config.vf_env_name, - "info": item_info, - } - - state = await self.vf_env.rollout( - input=rollout_input, - client=client, - model=self.server_configs[0].model_name, - sampling_args=sampling_args, - ) - - # Score the rollout (computes reward from test output) - await self.rubric.score_rollout(state, score_sem=score_sem) - - tokens, masks, logprobs, score = self._extract_from_state(state) - scored_data["tokens"].append(tokens) - scored_data["masks"].append(masks) - scored_data["inference_logprobs"].append(logprobs) - scored_data["scores"].append(score) - - # Track scores for wandb logging for score in scored_data["scores"]: self.percent_correct_buffer.append(max(score, 0)) @@ -534,38 +288,25 @@ class VerifiersEnv(BaseEnv): def _extract_from_state( self, state: Any ) -> Tuple[List[int], List[int], List[float], float]: - """ - Extract tokens/masks/logprobs/score from a single rollout state. 
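# Worked example of the prompt/completion mask and logprob convention used by
# _extract_from_state, with hypothetical token ids for a single-step trajectory:
prompt_ids = [101, 2009, 2003]
completion_ids = [7592, 2088]
completion_logprobs = [-0.12, -0.31]

tokens = prompt_ids + completion_ids                       # [101, 2009, 2003, 7592, 2088]
masks = [-100] * len(prompt_ids) + completion_ids          # prompt masked, completion kept
logprobs = [1.0] * len(prompt_ids) + completion_logprobs   # prompt logprobs are placeholders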
- - Handles the mask convention conversion: - - Verifiers: prompt_mask=0, completion_mask=1 - - Atropos: masked_tokens=-100 (prompt), token_id (completion) - """ + """Extract tokens/masks/logprobs from rollout state (RL mode only).""" all_tokens: List[int] = [] all_masks: List[int] = [] all_logprobs: List[float] = [] - trajectory = state.get("trajectory", []) - - for step in trajectory: + for step in state.get("trajectory", []): tokens = step["tokens"] - prompt_ids = tokens["prompt_ids"] completion_ids = tokens["completion_ids"] completion_logprobs = tokens["completion_logprobs"] all_tokens.extend(prompt_ids) all_tokens.extend(completion_ids) - all_masks.extend([-100] * len(prompt_ids)) all_masks.extend(completion_ids) - all_logprobs.extend([1.0] * len(prompt_ids)) all_logprobs.extend(completion_logprobs) - reward = state["reward"] - - return all_tokens, all_masks, all_logprobs, reward + return all_tokens, all_masks, all_logprobs, state["reward"] if __name__ == "__main__": From 9db6c0d1ed7073989b206bf6d1d3755c263505b6 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Mon, 12 Jan 2026 07:32:52 +0530 Subject: [PATCH 13/22] added better wandb logging --- environments/verifiers_server.py | 129 ++++++++++++++++++++++++++----- 1 file changed, 110 insertions(+), 19 deletions(-) diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 7ce38da7..13feb6f8 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -41,6 +41,7 @@ Docs: https://docs.primeintellect.ai/tutorials-environments/install import asyncio import logging +from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple import verifiers as vf @@ -75,13 +76,22 @@ class VerifiersEnv(BaseEnv): testing=False, ): super().__init__(config, server_configs, slurm, testing) - self.percent_correct_buffer: List[float] = [] + + # Metrics buffers for wandb logging + self.reward_buffer: List[float] = [] + self.metrics_buffer: Dict[str, List[float]] = defaultdict(list) + self.num_turns_buffer: List[int] = [] + self.groups_with_identical_scores: int = 0 + self.groups_total: int = 0 logger.info("Loading verifiers environment: %s", config.vf_env_name) self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric self.system_prompt = self.vf_env.system_prompt - logger.info("Reward functions: %s", self.rubric._get_reward_func_names()) + + # Get reward function names for metrics reporting + self.reward_func_names = self.rubric._get_reward_func_names() + logger.info("Reward functions: %s", self.reward_func_names) @classmethod def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: @@ -107,14 +117,57 @@ class VerifiersEnv(BaseEnv): return env_config, server_configs async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Enhanced wandb logging with verifiers-specific metrics.""" if wandb_metrics is None: wandb_metrics = {} - if self.percent_correct_buffer: - wandb_metrics["train/percent_correct"] = sum( - self.percent_correct_buffer - ) / len(self.percent_correct_buffer) - self.percent_correct_buffer = [] + # Log mean reward across all rollouts + if self.reward_buffer: + wandb_metrics["metrics/mean_reward"] = sum(self.reward_buffer) / len( + self.reward_buffer + ) + wandb_metrics["metrics/reward_std"] = ( + ( + sum( + (r - wandb_metrics["metrics/mean_reward"]) ** 2 + for r in self.reward_buffer + ) + / len(self.reward_buffer) + ) + ** 0.5 + if len(self.reward_buffer) > 1 + else 0.0 + ) + 
self.reward_buffer = [] + + # Log per-reward-function metrics (e.g., strict_accuracy, format_score) + if self.metrics_buffer: + for metric_name, values in self.metrics_buffer.items(): + if values: + avg_metric = sum(values) / len(values) + wandb_metrics[f"metrics/{metric_name}"] = avg_metric + self.metrics_buffer = defaultdict(list) + + # Log multi-turn statistics + if self.num_turns_buffer: + wandb_metrics["metrics/avg_num_turns"] = sum(self.num_turns_buffer) / len( + self.num_turns_buffer + ) + wandb_metrics["metrics/max_num_turns"] = max(self.num_turns_buffer) + self.num_turns_buffer = [] + + # Log group filtering statistics (helpful for debugging) + if self.groups_total > 0: + wandb_metrics["metrics/groups_with_identical_scores"] = ( + self.groups_with_identical_scores + ) + wandb_metrics["metrics/groups_total"] = self.groups_total + wandb_metrics["metrics/identical_score_rate"] = ( + self.groups_with_identical_scores / self.groups_total + ) + # Reset counters + self.groups_with_identical_scores = 0 + self.groups_total = 0 await super().wandb_log(wandb_metrics) @@ -220,10 +273,29 @@ class VerifiersEnv(BaseEnv): scored_data["tokens"].append(tokenized["tokens"]) scored_data["masks"].append(tokenized["masks"]) scored_data["messages"].append(messages) - scored_data["scores"].append(state.get("reward", 0.0)) - for score in scored_data["scores"]: - self.percent_correct_buffer.append(max(score, 0)) + reward = state.get("reward", 0.0) + scored_data["scores"].append(reward) + + # Capture metrics for wandb logging + self.reward_buffer.append(reward) + self.num_turns_buffer.append(len(trajectory)) + + # Extract per-function metrics from verifiers state + state_metrics = state.get("metrics", {}) + if state_metrics: + for metric_name, metric_value in state_metrics.items(): + if isinstance(metric_value, (int, float)): + self.metrics_buffer[metric_name].append(float(metric_value)) + + # Track group-level identical scores for debugging + self.groups_total += 1 + if len(set(scored_data["scores"])) == 1: + self.groups_with_identical_scores += 1 + logger.debug( + "Group has identical scores (%.3f) - will be filtered by base env", + scored_data["scores"][0], + ) return scored_data, [] @@ -243,9 +315,8 @@ class VerifiersEnv(BaseEnv): score_sem = asyncio.Semaphore(self.config.group_size) model = self.server.servers[0].config.model_name - async def run_rollout( - example_id: int, - ) -> Tuple[List[int], List[int], List[float], float]: + async def run_rollout(example_id: int) -> Dict[str, Any]: + """Run a single rollout and return full state for metrics extraction.""" async with self.server.managed_server(tokenizer=self.tokenizer) as managed: client = AtroposManagedClient(managed_server=managed, model=model) @@ -262,9 +333,9 @@ class VerifiersEnv(BaseEnv): sampling_args=sampling_args, ) await self.rubric.score_rollout(state, score_sem=score_sem) - return self._extract_from_state(state) + return state - results = await asyncio.gather( + states = await asyncio.gather( *[run_rollout(i) for i in range(self.config.group_size)] ) @@ -274,14 +345,34 @@ class VerifiersEnv(BaseEnv): scored_data["scores"] = [] scored_data["inference_logprobs"] = [] - for tokens, masks, logprobs, score in results: + for state in states: + tokens, masks, logprobs, reward = self._extract_from_state(state) scored_data["tokens"].append(tokens) scored_data["masks"].append(masks) scored_data["inference_logprobs"].append(logprobs) - scored_data["scores"].append(score) + scored_data["scores"].append(reward) - for score in scored_data["scores"]: 
- self.percent_correct_buffer.append(max(score, 0)) + # Capture metrics for wandb logging + self.reward_buffer.append(reward) + + trajectory = state.get("trajectory", []) + self.num_turns_buffer.append(len(trajectory)) + + # Extract per-function metrics from verifiers state + state_metrics = state.get("metrics", {}) + if state_metrics: + for metric_name, metric_value in state_metrics.items(): + if isinstance(metric_value, (int, float)): + self.metrics_buffer[metric_name].append(float(metric_value)) + + # Track group-level identical scores for debugging + self.groups_total += 1 + if len(set(scored_data["scores"])) == 1: + self.groups_with_identical_scores += 1 + logger.debug( + "Group has identical scores (%.3f) - will be filtered by base env", + scored_data["scores"][0], + ) return scored_data, [] From 49687304ef3bbbe408e5e9177aa8524329eed92f Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Mon, 12 Jan 2026 10:33:39 +0530 Subject: [PATCH 14/22] fix verifiers conflict --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fff88547..e61f25a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "numpy", "wandb", "gymnasium", - "math-verify==0.7.0", + "math-verify>=0.8.0", "jinja2", "nltk", "rich", From 7907ffd0ad3fe36bb67d4e72b383f0f5a101b159 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 05:05:07 +0000 Subject: [PATCH 15/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- environments/eval_environments/verifiers_eval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index df28973b..d6d8cf98 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -31,11 +31,10 @@ import time from typing import Any, Dict, List, Optional, Tuple import verifiers as vf +import wandb from openai import AsyncOpenAI from pydantic import Field -import wandb - from atroposlib.envs.base import ( APIServerConfig, BaseEnv, From a1d1e7d7feed25f7c4df4894a3807e32d4d69020 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Mon, 12 Jan 2026 10:39:43 +0530 Subject: [PATCH 16/22] fix env_args, dataset/prompt loading --- environments/verifiers_server.py | 92 ++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 13feb6f8..3e70081e 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -8,13 +8,13 @@ Supports TWO modes: Usage: # RL Training (requires local vLLM/SGLang server) python verifiers_server.py serve \ - --env.vf_env_name "will/wordle" \ + --env.vf_env_name "primeintellect/alphabet-sort" \ --openai.base_url http://localhost:9001/v1 \ --slurm false # SFT Data Generation with OpenAI GPT-4o python verifiers_server.py process \ - --env.vf_env_name "will/wordle" \ + --env.vf_env_name "primeintellect/alphabet-sort" \ --env.data_path_to_save_groups gpt4o_sft_data.jsonl \ --env.total_steps 100 \ --env.group_size 4 \ @@ -23,30 +23,30 @@ Usage: # SFT Data Generation with local server python verifiers_server.py process \ - --env.vf_env_name "will/wordle" \ + --env.vf_env_name "primeintellect/alphabet-sort" \ --env.data_path_to_save_groups 
local_sft_data.jsonl \ --openai.base_url http://localhost:9001/v1 # Evaluation (uses ManagedServer by default, falls back to direct API in process mode) python verifiers_server.py evaluate \ - --env.vf_env_name "will/wordle" \ + --env.vf_env_name "primeintellect/alphabet-sort" \ --openai.base_url http://localhost:9001/v1 To install a Verifiers/Prime environment: 1. uv tool install prime 2. prime login -3. prime env install will/wordle (or any owner/environment) +3. prime env install primeintellect/alphabet-sort (or any owner/environment) Docs: https://docs.primeintellect.ai/tutorials-environments/install """ import asyncio +import json import logging from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple import verifiers as vf from openai import AsyncOpenAI -from pydantic import Field from atroposlib.envs.base import ( APIServerConfig, @@ -61,7 +61,13 @@ logger = logging.getLogger(__name__) class VfEnvConfig(BaseEnvConfig): vf_env_name: str = "" - env_args: Dict[str, Any] = Field(default_factory=dict) + env_args: str = "{}" + + def get_env_args(self) -> Dict[str, Any]: + """Parse env_args JSON string into dict.""" + if isinstance(self.env_args, dict): + return self.env_args + return json.loads(self.env_args) class VerifiersEnv(BaseEnv): @@ -85,7 +91,10 @@ class VerifiersEnv(BaseEnv): self.groups_total: int = 0 logger.info("Loading verifiers environment: %s", config.vf_env_name) - self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) + env_args = config.get_env_args() + if env_args: + logger.info("Environment args: %s", env_args) + self.vf_env = vf.load_environment(config.vf_env_name, **env_args) self.rubric = self.vf_env.rubric self.system_prompt = self.vf_env.system_prompt @@ -93,6 +102,10 @@ class VerifiersEnv(BaseEnv): self.reward_func_names = self.rubric._get_reward_func_names() logger.info("Reward functions: %s", self.reward_func_names) + # Log multi-turn config if available + if hasattr(self.vf_env, "max_turns"): + logger.info("Max turns: %d", self.vf_env.max_turns) + @classmethod def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]: env_config = VfEnvConfig( @@ -172,10 +185,9 @@ class VerifiersEnv(BaseEnv): await super().wandb_log(wandb_metrics) async def setup(self): + # Dataset already has: prompt, answer, info, example_id, task train_data = self.vf_env.get_dataset() - columns_to_keep = ["question", "answer", "info"] - available_columns = [c for c in columns_to_keep if c in train_data.column_names] - self.train = train_data.select_columns(available_columns).to_list() + self.train = train_data.to_list() self.iter = 0 def save_checkpoint(self, step, data=None): @@ -193,13 +205,6 @@ class VerifiersEnv(BaseEnv): """No-op. 
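# Sketch of the string-typed env_args plumbing: the CLI passes a JSON string,
# get_env_args() parses it, and the kwargs go to vf.load_environment(). The key
# below is purely illustrative, not a documented environment argument:
#
#   --env.env_args '{"num_eval_examples": 50}'
#
import json

parsed = json.loads('{"num_eval_examples": 50}')  # mirrors config.get_env_args()
assert parsed == {"num_eval_examples": 50}
# vf.load_environment(vf_env_name, **parsed) then receives num_eval_examples=50.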
Use environments/eval_environments/verifiers_eval.py for evaluation.""" return {} - def _build_initial_messages(self, question: str) -> List[Dict[str, str]]: - messages: List[Dict[str, str]] = [] - if self.system_prompt: - messages.append({"role": "system", "content": self.system_prompt}) - messages.append({"role": "user", "content": question}) - return messages - async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: is_process_mode = getattr(self, "process_mode", False) if is_process_mode: @@ -217,16 +222,16 @@ class VerifiersEnv(BaseEnv): timeout=server_config.timeout, ) - initial_messages = self._build_initial_messages(item["question"]) + # item already has prompt, answer, example_id, task, info from dataset inputs = [ { - "prompt": initial_messages, - "answer": item["answer"], - "example_id": i, - "task": self.config.vf_env_name, + "prompt": item["prompt"], + "answer": item.get("answer", ""), + "example_id": item["example_id"], + "task": item.get("task", self.config.vf_env_name), "info": item.get("info", {}), } - for i in range(self.config.group_size) + for _ in range(self.config.group_size) ] results = await self.vf_env.generate( @@ -279,7 +284,9 @@ class VerifiersEnv(BaseEnv): # Capture metrics for wandb logging self.reward_buffer.append(reward) - self.num_turns_buffer.append(len(trajectory)) + num_turns = len(trajectory) + self.num_turns_buffer.append(num_turns) + logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward) # Extract per-function metrics from verifiers state state_metrics = state.get("metrics", {}) @@ -288,6 +295,15 @@ class VerifiersEnv(BaseEnv): if isinstance(metric_value, (int, float)): self.metrics_buffer[metric_name].append(float(metric_value)) + # Log group summary + turns = [len(s.get("trajectory", [])) for s in results["state"]] + logger.info( + "Group: %d rollouts, turns=%s, rewards=%s", + len(results["state"]), + turns, + [f"{s:.3f}" for s in scored_data["scores"]], + ) + # Track group-level identical scores for debugging self.groups_total += 1 if len(set(scored_data["scores"])) == 1: @@ -307,7 +323,6 @@ class VerifiersEnv(BaseEnv): AtroposManagedClient, ) - initial_messages = self._build_initial_messages(item["question"]) sampling_args = { "temperature": 1.0, "max_completion_tokens": self.config.max_token_length, @@ -315,17 +330,17 @@ class VerifiersEnv(BaseEnv): score_sem = asyncio.Semaphore(self.config.group_size) model = self.server.servers[0].config.model_name - async def run_rollout(example_id: int) -> Dict[str, Any]: + async def run_rollout() -> Dict[str, Any]: """Run a single rollout and return full state for metrics extraction.""" async with self.server.managed_server(tokenizer=self.tokenizer) as managed: client = AtroposManagedClient(managed_server=managed, model=model) state = await self.vf_env.rollout( input={ - "prompt": initial_messages, - "answer": item["answer"], - "example_id": example_id, - "task": self.config.vf_env_name, + "prompt": item["prompt"], + "answer": item.get("answer", ""), + "example_id": item["example_id"], + "task": item.get("task", self.config.vf_env_name), "info": item.get("info", {}), }, client=client, @@ -336,7 +351,7 @@ class VerifiersEnv(BaseEnv): return state states = await asyncio.gather( - *[run_rollout(i) for i in range(self.config.group_size)] + *[run_rollout() for _ in range(self.config.group_size)] ) scored_data = ScoredDataGroup() @@ -356,7 +371,9 @@ class VerifiersEnv(BaseEnv): self.reward_buffer.append(reward) trajectory = state.get("trajectory", []) - 
self.num_turns_buffer.append(len(trajectory)) + num_turns = len(trajectory) + self.num_turns_buffer.append(num_turns) + logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward) # Extract per-function metrics from verifiers state state_metrics = state.get("metrics", {}) @@ -365,6 +382,15 @@ class VerifiersEnv(BaseEnv): if isinstance(metric_value, (int, float)): self.metrics_buffer[metric_name].append(float(metric_value)) + # Log group summary + turns = [len(s.get("trajectory", [])) for s in states] + logger.info( + "Group: %d rollouts, turns=%s, rewards=%s", + len(states), + turns, + [f"{s:.3f}" for s in scored_data["scores"]], + ) + # Track group-level identical scores for debugging self.groups_total += 1 if len(set(scored_data["scores"])) == 1: From 32320512e831253e5272f10a26f349ca6e329c53 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Tue, 13 Jan 2026 15:00:54 +0530 Subject: [PATCH 17/22] update verifiers_server to use tokenizer_for_trainer --- environments/verifiers_server.py | 156 ++++--------------------------- 1 file changed, 20 insertions(+), 136 deletions(-) diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 3e70081e..44c033b9 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -1,12 +1,12 @@ """ Verifiers Training Environment for Atropos -Supports TWO modes: -- serve: RL training with local inference server (requires ManagedServer for logprobs) -- process: SFT data generation with ANY API (OpenAI, Claude, local, etc.) +Unified environment that works for both RL training (serve) and SFT data generation (process). +Uses vf_env.generate() with standard AsyncOpenAI client and tokenize_for_trainer() for +token/mask generation. No inference logprobs needed - GRPO computes fresh logprobs during training. Usage: - # RL Training (requires local vLLM/SGLang server) + # RL Training (GRPO - no inference logprobs needed) python verifiers_server.py serve \ --env.vf_env_name "primeintellect/alphabet-sort" \ --openai.base_url http://localhost:9001/v1 \ @@ -27,11 +27,6 @@ Usage: --env.data_path_to_save_groups local_sft_data.jsonl \ --openai.base_url http://localhost:9001/v1 - # Evaluation (uses ManagedServer by default, falls back to direct API in process mode) - python verifiers_server.py evaluate \ - --env.vf_env_name "primeintellect/alphabet-sort" \ - --openai.base_url http://localhost:9001/v1 - To install a Verifiers/Prime environment: 1. uv tool install prime 2. prime login @@ -39,7 +34,6 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install """ -import asyncio import json import logging from collections import defaultdict @@ -206,15 +200,12 @@ class VerifiersEnv(BaseEnv): return {} async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: - is_process_mode = getattr(self, "process_mode", False) - if is_process_mode: - return await self._collect_for_sft(item) - return await self._collect_for_rl(item) + """Unified trajectory collection using vf_env.generate(). - async def _collect_for_sft( - self, item: Dict[str, Any] - ) -> Tuple[ScoredDataGroup, list]: - """SFT mode: uses vf_env.generate() for parallel generation + scoring.""" + Works for both RL training (serve) and SFT data generation (process). + Uses tokenize_for_trainer() for token/mask generation - no inference + logprobs needed since GRPO computes fresh logprobs during training. 
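# A rough sketch of the masking behaviour relied on here, using only the
# tokenize_for_trainer arguments that appear in this patch; the chat content is
# made up for illustration:
def _example_multi_turn_masking(tokenizer):
    chat = [
        {"role": "system", "content": "Sort the letters alphabetically."},
        {"role": "user", "content": "b a c"},
        {"role": "assistant", "content": "a b c"},
        {"role": "user", "content": "Now: z y"},
        {"role": "assistant", "content": "y z"},
    ]
    tokenized = tokenize_for_trainer(
        tokenizer=tokenizer,
        chat=chat,
        include_messages=True,
        finish_reason="stop",
        train_on_all_assistant_turns=True,  # unmask every assistant turn, not just the last
    )
    # tokenized["tokens"] covers the full rendered chat; tokenized["masks"] is
    # -100 over system/user spans and the token id over both assistant turns.
    return tokenized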
+ """ server_config = self.server.servers[0].config client = AsyncOpenAI( api_key=server_config.api_key, @@ -222,7 +213,7 @@ class VerifiersEnv(BaseEnv): timeout=server_config.timeout, ) - # item already has prompt, answer, example_id, task, info from dataset + # Build inputs for group_size rollouts inputs = [ { "prompt": item["prompt"], @@ -234,6 +225,7 @@ class VerifiersEnv(BaseEnv): for _ in range(self.config.group_size) ] + # Use vf_env.generate() - handles batching and scoring internally results = await self.vf_env.generate( inputs=inputs, client=client, @@ -255,11 +247,13 @@ class VerifiersEnv(BaseEnv): scored_data["messages"] = [] for state in results["state"]: + # Extract messages from state messages = list(state.get("prompt", [])) + list(state.get("completion", [])) messages = [ {**msg, "content": msg.get("content") or ""} for msg in messages ] + # Get finish_reason for proper tokenization trajectory = state.get("trajectory", []) finish_reason = ( trajectory[-1]["response"].choices[0].finish_reason @@ -267,6 +261,7 @@ class VerifiersEnv(BaseEnv): else "stop" ) + # Tokenize with multi-turn support tokenized = tokenize_for_trainer( tokenizer=self.tokenizer, chat=messages, @@ -282,18 +277,17 @@ class VerifiersEnv(BaseEnv): reward = state.get("reward", 0.0) scored_data["scores"].append(reward) - # Capture metrics for wandb logging + # Metrics logging self.reward_buffer.append(reward) num_turns = len(trajectory) self.num_turns_buffer.append(num_turns) logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward) - # Extract per-function metrics from verifiers state + # Per-function metrics from verifiers state state_metrics = state.get("metrics", {}) - if state_metrics: - for metric_name, metric_value in state_metrics.items(): - if isinstance(metric_value, (int, float)): - self.metrics_buffer[metric_name].append(float(metric_value)) + for metric_name, metric_value in state_metrics.items(): + if isinstance(metric_value, (int, float)): + self.metrics_buffer[metric_name].append(float(metric_value)) # Log group summary turns = [len(s.get("trajectory", [])) for s in results["state"]] @@ -304,7 +298,7 @@ class VerifiersEnv(BaseEnv): [f"{s:.3f}" for s in scored_data["scores"]], ) - # Track group-level identical scores for debugging + # Track identical scores for debugging self.groups_total += 1 if len(set(scored_data["scores"])) == 1: self.groups_with_identical_scores += 1 @@ -315,116 +309,6 @@ class VerifiersEnv(BaseEnv): return scored_data, [] - async def _collect_for_rl( - self, item: Dict[str, Any] - ) -> Tuple[ScoredDataGroup, list]: - """RL mode: uses ManagedServer for logprobs tracking.""" - from atroposlib.envs.server_handling.atropos_managed_client import ( - AtroposManagedClient, - ) - - sampling_args = { - "temperature": 1.0, - "max_completion_tokens": self.config.max_token_length, - } - score_sem = asyncio.Semaphore(self.config.group_size) - model = self.server.servers[0].config.model_name - - async def run_rollout() -> Dict[str, Any]: - """Run a single rollout and return full state for metrics extraction.""" - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - client = AtroposManagedClient(managed_server=managed, model=model) - - state = await self.vf_env.rollout( - input={ - "prompt": item["prompt"], - "answer": item.get("answer", ""), - "example_id": item["example_id"], - "task": item.get("task", self.config.vf_env_name), - "info": item.get("info", {}), - }, - client=client, - model=model, - sampling_args=sampling_args, - ) - await 
self.rubric.score_rollout(state, score_sem=score_sem) - return state - - states = await asyncio.gather( - *[run_rollout() for _ in range(self.config.group_size)] - ) - - scored_data = ScoredDataGroup() - scored_data["tokens"] = [] - scored_data["masks"] = [] - scored_data["scores"] = [] - scored_data["inference_logprobs"] = [] - - for state in states: - tokens, masks, logprobs, reward = self._extract_from_state(state) - scored_data["tokens"].append(tokens) - scored_data["masks"].append(masks) - scored_data["inference_logprobs"].append(logprobs) - scored_data["scores"].append(reward) - - # Capture metrics for wandb logging - self.reward_buffer.append(reward) - - trajectory = state.get("trajectory", []) - num_turns = len(trajectory) - self.num_turns_buffer.append(num_turns) - logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward) - - # Extract per-function metrics from verifiers state - state_metrics = state.get("metrics", {}) - if state_metrics: - for metric_name, metric_value in state_metrics.items(): - if isinstance(metric_value, (int, float)): - self.metrics_buffer[metric_name].append(float(metric_value)) - - # Log group summary - turns = [len(s.get("trajectory", [])) for s in states] - logger.info( - "Group: %d rollouts, turns=%s, rewards=%s", - len(states), - turns, - [f"{s:.3f}" for s in scored_data["scores"]], - ) - - # Track group-level identical scores for debugging - self.groups_total += 1 - if len(set(scored_data["scores"])) == 1: - self.groups_with_identical_scores += 1 - logger.debug( - "Group has identical scores (%.3f) - will be filtered by base env", - scored_data["scores"][0], - ) - - return scored_data, [] - - def _extract_from_state( - self, state: Any - ) -> Tuple[List[int], List[int], List[float], float]: - """Extract tokens/masks/logprobs from rollout state (RL mode only).""" - all_tokens: List[int] = [] - all_masks: List[int] = [] - all_logprobs: List[float] = [] - - for step in state.get("trajectory", []): - tokens = step["tokens"] - prompt_ids = tokens["prompt_ids"] - completion_ids = tokens["completion_ids"] - completion_logprobs = tokens["completion_logprobs"] - - all_tokens.extend(prompt_ids) - all_tokens.extend(completion_ids) - all_masks.extend([-100] * len(prompt_ids)) - all_masks.extend(completion_ids) - all_logprobs.extend([1.0] * len(prompt_ids)) - all_logprobs.extend(completion_logprobs) - - return all_tokens, all_masks, all_logprobs, state["reward"] - if __name__ == "__main__": VerifiersEnv.cli() From 6a27e88023d2990f1c4156e9a977833dbe294796 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Wed, 14 Jan 2026 17:09:01 +0530 Subject: [PATCH 18/22] use managed server --- .../envs/server_handling/vllm_server.py | 2 + environments/verifiers_server.py | 182 ++++++++++++------ 2 files changed, 128 insertions(+), 56 deletions(-) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 8e043cf9..48c8cb8d 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -169,6 +169,8 @@ class VLLMServer(APIServer): prompt_tokens = prompt_tokens[1:] if "max_new_tokens" in kwargs: kwargs["max_tokens"] = kwargs.pop("max_new_tokens") + if "max_completion_tokens" in kwargs: + kwargs["max_tokens"] = kwargs.pop("max_completion_tokens") if "model" in kwargs: kwargs.pop("model") # Prepare request for VLLM native API diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 44c033b9..0cd8cf45 100644 --- 
a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -2,8 +2,7 @@ Verifiers Training Environment for Atropos Unified environment that works for both RL training (serve) and SFT data generation (process). -Uses vf_env.generate() with standard AsyncOpenAI client and tokenize_for_trainer() for -token/mask generation. No inference logprobs needed - GRPO computes fresh logprobs during training. +Uses vf_env.generate() with ManagedServer (via adapter) for automatic token and logprob tracking. Usage: # RL Training (GRPO - no inference logprobs needed) @@ -40,7 +39,6 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple import verifiers as vf -from openai import AsyncOpenAI from atroposlib.envs.base import ( APIServerConfig, @@ -48,11 +46,65 @@ from atroposlib.envs.base import ( BaseEnvConfig, ScoredDataGroup, ) -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +from atroposlib.envs.server_handling.managed_server import ManagedServer logger = logging.getLogger(__name__) +class ManagedServerAdapter: + """ + Adapter that makes ManagedServer look like AsyncOpenAI for verifiers. + + Implements the subset of AsyncOpenAI interface that verifiers uses: + - client.chat.completions.create() + - client.completions.create() + - client.base_url + """ + + def __init__(self, managed_server: ManagedServer, base_url: str): + self._managed = managed_server + self.base_url = base_url + self.chat = self._ChatNamespace(self._managed) + self.completions = self._CompletionsNamespace(self._managed) + + class _ChatNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed) + + class _ChatCompletionsNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + + async def create(self, **kwargs): + logger.info( + "ManagedServerAdapter.chat.completions.create called with model=%s", + kwargs.get("model"), + ) + result = await self._managed.chat_completion(**kwargs) + logger.info("ManagedServerAdapter.chat.completions.create completed") + return result + + class _CompletionsNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + + async def create(self, **kwargs): + return await self._managed.completion(**kwargs) + + async def post(self, path: str, body: dict, cast_to: type): + raise NotImplementedError( + f"ManagedServerAdapter does not support post() for path '{path}'. " + "This is used for vLLM interleaved rollouts. Use standard chat completions." + ) + + def copy(self, **kwargs): + raise NotImplementedError( + "ManagedServerAdapter does not support copy(). " + "This is used for vLLM tokenization endpoints." + ) + + class VfEnvConfig(BaseEnvConfig): vf_env_name: str = "" env_args: str = "{}" @@ -115,10 +167,11 @@ class VerifiersEnv(BaseEnv): ) server_configs = [ APIServerConfig( - model_name="gpt-4.1-nano", - base_url="https://api.openai.com/v1", + model_name="Qwen/Qwen2.5-1.5B-Instruct", + base_url="http://localhost:9001/v1", api_key="x", num_requests_for_eval=4, + server_type="sglang", ), ] return env_config, server_configs @@ -200,18 +253,20 @@ class VerifiersEnv(BaseEnv): return {} async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: - """Unified trajectory collection using vf_env.generate(). + """Unified trajectory collection using vf_env.generate() with ManagedServer. Works for both RL training (serve) and SFT data generation (process). 
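# Minimal usage sketch for the adapter introduced in this patch (method name and
# literal values are placeholders; assumes it runs inside the env where
# self.server and self.tokenizer exist):
async def _example_adapter_usage(self):
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        adapter = ManagedServerAdapter(
            managed_server=managed,
            base_url="http://localhost:9001/v1",
        )
        # verifiers only touches the AsyncOpenAI-shaped surface, so its
        # generate()/rollout() calls route through ManagedServer.chat_completion
        # and are token/logprob tracked automatically.
        completion = await adapter.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[{"role": "user", "content": "hello"}],
            max_tokens=32,
        )
        nodes = managed.get_state()["nodes"]  # tracked tokens/masks/logprobs per sequence
        return completion, nodes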
- Uses tokenize_for_trainer() for token/mask generation - no inference - logprobs needed since GRPO computes fresh logprobs during training. + Uses ManagedServer adapter for automatic token and logprob tracking. """ - server_config = self.server.servers[0].config - client = AsyncOpenAI( - api_key=server_config.api_key, - base_url=server_config.base_url, - timeout=server_config.timeout, - ) + # Get server config (handle both real servers and test harness) + if hasattr(self.server, "servers") and self.server.servers: + server_config = self.server.servers[0].config + else: + # Fallback for testing + server_config = APIServerConfig( + model_name=self.config.tokenizer_name, + base_url="http://localhost:8000/v1", + ) # Build inputs for group_size rollouts inputs = [ @@ -225,56 +280,70 @@ class VerifiersEnv(BaseEnv): for _ in range(self.config.group_size) ] - # Use vf_env.generate() - handles batching and scoring internally - results = await self.vf_env.generate( - inputs=inputs, - client=client, - model=server_config.model_name, - sampling_args={ - "temperature": 1.0, - "max_completion_tokens": self.config.max_token_length, - }, - max_concurrent=self.config.group_size, - max_concurrent_scoring=self.config.group_size, - save_results=False, - independent_scoring=True, - ) + # Use ManagedServer for automatic token/logprob tracking + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Create adapter that looks like AsyncOpenAI for verifiers + adapter = ManagedServerAdapter( + managed_server=managed, + base_url=server_config.base_url, + ) - scored_data = ScoredDataGroup() - scored_data["tokens"] = [] - scored_data["masks"] = [] - scored_data["scores"] = [] - scored_data["messages"] = [] + # Use vf_env.generate() - handles batching and scoring internally + results = await self.vf_env.generate( + inputs=inputs, + client=adapter, + model=server_config.model_name, + sampling_args={ + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + }, + max_concurrent=self.config.group_size, + max_concurrent_scoring=self.config.group_size, + save_results=False, + independent_scoring=True, + ) - for state in results["state"]: + # Get tracked state from ManagedServer + managed_state = managed.get_state() + nodes = managed_state["nodes"] + + scored_data: ScoredDataGroup = { + "tokens": [], + "masks": [], + "scores": [], + "messages": [], + "inference_logprobs": [], + } + + # Zip verifiers states with ManagedServer nodes for logprob tracking + for i, vf_state in enumerate(results["state"]): # Extract messages from state - messages = list(state.get("prompt", [])) + list(state.get("completion", [])) + messages = list(vf_state.get("prompt", [])) + list( + vf_state.get("completion", []) + ) messages = [ {**msg, "content": msg.get("content") or ""} for msg in messages ] - # Get finish_reason for proper tokenization - trajectory = state.get("trajectory", []) - finish_reason = ( - trajectory[-1]["response"].choices[0].finish_reason - if trajectory - else "stop" - ) + # Get trajectory for metrics + trajectory = vf_state.get("trajectory", []) - # Tokenize with multi-turn support - tokenized = tokenize_for_trainer( - tokenizer=self.tokenizer, - chat=messages, - include_messages=True, - finish_reason=finish_reason, - train_on_all_assistant_turns=True, - ) + # Get tokens, masks, and logprobs from ManagedServer + # IMPORTANT: We use ManagedServer's tokens (not re-tokenize) to ensure + # alignment with logprobs. ManagedServer tracks tokens and logprobs together. 
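# What each tracked node is expected to expose at this point (only the
# attributes this loop reads; a sketch of the convention, not the full
# SequenceNode API):
#   node.tokens        -> full token ids for the rollout (prompt + all turns)
#   node.masked_tokens -> -100 over prompt/user spans, the token id over model spans
#   node.logprobs      -> inference-time logprobs aligned one-to-one with node.tokens
# so for every node: len(node.tokens) == len(node.masked_tokens) == len(node.logprobs)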
+ if i >= len(nodes): + raise RuntimeError( + f"Node count mismatch: expected at least {i + 1} nodes, got {len(nodes)}. " + "ManagedServer should track all rollouts." + ) - scored_data["tokens"].append(tokenized["tokens"]) - scored_data["masks"].append(tokenized["masks"]) + node = nodes[i] + scored_data["tokens"].append(node.tokens) + scored_data["masks"].append(node.masked_tokens) + scored_data["inference_logprobs"].append(node.logprobs) scored_data["messages"].append(messages) - reward = state.get("reward", 0.0) + reward = vf_state.get("reward", 0.0) scored_data["scores"].append(reward) # Metrics logging @@ -284,7 +353,7 @@ class VerifiersEnv(BaseEnv): logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward) # Per-function metrics from verifiers state - state_metrics = state.get("metrics", {}) + state_metrics = vf_state.get("metrics", {}) for metric_name, metric_value in state_metrics.items(): if isinstance(metric_value, (int, float)): self.metrics_buffer[metric_name].append(float(metric_value)) @@ -292,10 +361,11 @@ class VerifiersEnv(BaseEnv): # Log group summary turns = [len(s.get("trajectory", [])) for s in results["state"]] logger.info( - "Group: %d rollouts, turns=%s, rewards=%s", + "Group: %d rollouts, turns=%s, rewards=%s, nodes=%d", len(results["state"]), turns, [f"{s:.3f}" for s in scored_data["scores"]], + len(nodes), ) # Track identical scores for debugging From 57fa229846e2390ec5b2486b47b46db53b4e7f36 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Wed, 14 Jan 2026 17:09:57 +0530 Subject: [PATCH 19/22] remove unused managed_server wrapper + tese --- .../server_handling/atropos_managed_client.py | 330 ------------------ .../tests/test_atropos_managed_client.py | 235 ------------- 2 files changed, 565 deletions(-) delete mode 100644 atroposlib/envs/server_handling/atropos_managed_client.py delete mode 100644 atroposlib/tests/test_atropos_managed_client.py diff --git a/atroposlib/envs/server_handling/atropos_managed_client.py b/atroposlib/envs/server_handling/atropos_managed_client.py deleted file mode 100644 index 75e78e39..00000000 --- a/atroposlib/envs/server_handling/atropos_managed_client.py +++ /dev/null @@ -1,330 +0,0 @@ -""" -AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer. - -This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's -ManagedServer for inference, enabling token tracking for multi-turn RL training -with the Verifiers library. - -Usage: - async with server_manager.managed_server(tokenizer=tokenizer) as managed: - client = AtroposManagedClient(managed_server=managed, model="model-name") - - # Use like AsyncOpenAI - tokens are tracked automatically - response = await client.chat.completions.create( - messages=[{"role": "user", "content": "Hello"}], - max_tokens=100 - ) - - # Token data is available on the response: - # - response.prompt_token_ids - # - response.choices[0].token_ids - # - response.choices[0].logprobs.content[i].logprob -""" - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -from openai.types.chat.chat_completion_message import ChatCompletionMessage - -from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode - -# ============================================================================= -# Enhanced Types for Token Data Injection -# ============================================================================= - - -@dataclass -class LogprobContent: - """ - Single token logprob entry. 
- - Compatible with verifiers' parse_response_tokens() which accesses: - - response.choices[i].logprobs.content[j].logprob - """ - - logprob: float - token: str = "" - token_id: int = 0 - top_logprobs: Optional[List[Any]] = None - - -@dataclass -class ChoiceLogprobs: - """ - Logprobs structure compatible with verifiers expectations. - - Verifiers checks for either object or dict format: - - Object: response.choices[i].logprobs.content[j].logprob - - Dict: response.choices[i].logprobs["content"][j]["logprob"] - - This dataclass supports the object format. - """ - - content: List[LogprobContent] = field(default_factory=list) - - -@dataclass -class EnhancedChoice: - """ - Choice with token_ids and logprobs for RL training. - - Adds the following attributes that verifiers expects: - - token_ids: List[int] - completion token IDs - - logprobs: ChoiceLogprobs - structured logprobs - """ - - index: int - message: ChatCompletionMessage - finish_reason: str - token_ids: List[int] - logprobs: ChoiceLogprobs - - -@dataclass -class EnhancedChatCompletion: - """ - ChatCompletion with token data for RL training. - - Compatible with verifiers' parse_response_tokens() expectations: - - prompt_token_ids: list[int] - - choices[i].token_ids: list[int] - - choices[i].logprobs.content[j].logprob - """ - - id: str - created: int - model: str - object: str - choices: List[EnhancedChoice] - prompt_token_ids: List[int] - usage: Optional[Dict[str, int]] = None - - -# ============================================================================= -# AsyncOpenAI-Compatible Client Classes -# ============================================================================= - - -class _CompletionsNamespace: - """ - Mimics openai.resources.chat.completions.AsyncCompletions. - - Provides the create() method that verifiers calls. - """ - - def __init__(self, parent: "AtroposManagedClient"): - self.parent = parent - - async def create( - self, - *, - messages: List[Dict[str, Any]], - model: Optional[str] = None, - n: int = 1, - max_tokens: Optional[int] = None, - max_completion_tokens: Optional[int] = None, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict]] = None, - stop: Optional[List[str]] = None, - **kwargs, - ) -> EnhancedChatCompletion: - """ - Create chat completion with token tracking. 
- - Returns ChatCompletion with additional attributes: - - prompt_token_ids: list[int] - - choices[i].token_ids: list[int] - - choices[i].logprobs.content: list with logprob info - - Args: - messages: List of message dicts with 'role' and 'content' - model: Model name (defaults to client's model) - n: Number of completions (should be 1 for multi-turn) - max_tokens: Max tokens in completion (legacy param) - max_completion_tokens: Max tokens in completion (new param) - temperature: Sampling temperature - top_p: Nucleus sampling parameter - tools: Tool definitions for function calling - stop: Stop sequences - **kwargs: Additional parameters passed to ManagedServer - """ - # Use max_completion_tokens if provided, else max_tokens - effective_max_tokens = max_completion_tokens or max_tokens - - # Build kwargs for ManagedServer - completion_kwargs = { - "messages": messages, - "model": model or self.parent.model, - "n": n, - "temperature": temperature, - "top_p": top_p, - } - - if effective_max_tokens is not None: - completion_kwargs["max_tokens"] = effective_max_tokens - - if tools is not None: - completion_kwargs["tools"] = tools - - if stop is not None: - completion_kwargs["stop"] = stop - - # Add any extra kwargs (like logprobs settings) - for key, value in kwargs.items(): - if value is not None: - completion_kwargs[key] = value - - # Call ManagedServer for inference - completion = await self.parent.managed_server.chat_completion( - **completion_kwargs - ) - - # Get token state from managed server - state = self.parent.managed_server.get_state() - nodes: List[SequenceNode] = state["nodes"] - - # Inject token data into response - return self._enhance_completion(completion, nodes) - - def _enhance_completion( - self, completion: Any, nodes: List[SequenceNode] - ) -> EnhancedChatCompletion: - """ - Convert ManagedServer output to verifiers-compatible format. - - Extracts token data from SequenceNodes and injects it into the - ChatCompletion response in the format verifiers expects. 
- """ - enhanced_choices = [] - prompt_token_ids: List[int] = [] - - for i, (choice, node) in enumerate(zip(completion.choices, nodes)): - # Find prompt/completion boundary from masked_tokens - # -100 indicates prompt tokens, actual token IDs indicate completion - prompt_len = sum(1 for m in node.masked_tokens if m == -100) - - # Extract prompt and completion portions - if i == 0: - prompt_token_ids = node.tokens[:prompt_len] - - completion_ids = node.tokens[prompt_len:] - completion_logprobs = node.logprobs[prompt_len:] - - # Build logprobs structure verifiers expects - logprobs_content = [] - tokenizer = self.parent.managed_server.tokenizer - - for token_id, logprob in zip(completion_ids, completion_logprobs): - # Decode token to string if tokenizer available - token_str = "" - if tokenizer is not None: - try: - token_str = tokenizer.decode([token_id]) - except Exception: - token_str = f"" - - logprobs_content.append( - LogprobContent( - logprob=logprob, - token=token_str, - token_id=token_id, - ) - ) - - # Create enhanced choice with token data - enhanced_choice = EnhancedChoice( - index=choice.index, - message=choice.message, - finish_reason=choice.finish_reason or "stop", - token_ids=completion_ids, - logprobs=ChoiceLogprobs(content=logprobs_content), - ) - enhanced_choices.append(enhanced_choice) - - return EnhancedChatCompletion( - id=completion.id, - created=completion.created, - model=completion.model, - object=completion.object, - choices=enhanced_choices, - prompt_token_ids=prompt_token_ids, - usage=completion.usage.model_dump() if completion.usage else None, - ) - - -class _ChatNamespace: - """Mimics openai.resources.chat.AsyncChat.""" - - def __init__(self, parent: "AtroposManagedClient"): - self.completions = _CompletionsNamespace(parent) - - -class AtroposManagedClient: - """ - AsyncOpenAI-compatible client backed by ManagedServer. - - This client provides the same interface as AsyncOpenAI but uses Atropos's - ManagedServer for inference, enabling automatic token tracking for - multi-turn RL training with the Verifiers library. - - The key feature is that responses include token data attributes that - verifiers' parse_response_tokens() expects: - - response.prompt_token_ids - - response.choices[i].token_ids - - response.choices[i].logprobs.content[j].logprob - - Usage: - async with server_manager.managed_server(tokenizer=tokenizer) as managed: - client = AtroposManagedClient( - managed_server=managed, - model="Qwen/Qwen2.5-1.5B-Instruct" - ) - - # Pass to verifiers env.rollout() - state = await vf_env.rollout( - input=rollout_input, - client=client, - model="Qwen/Qwen2.5-1.5B-Instruct", - ) - - # Token data is now in state["trajectory"][i]["tokens"] - """ - - def __init__( - self, - managed_server: ManagedServer, - model: str, - base_url: Optional[str] = None, - ): - """ - Initialize the managed client. 
- - Args: - managed_server: ManagedServer instance for inference and token tracking - model: Model name to use for completions - base_url: Optional base URL (for API compatibility, not used) - """ - self.managed_server = managed_server - self.model = model - self.base_url = base_url or "http://managed-server" - - # Mimic AsyncOpenAI namespace structure - self.chat = _ChatNamespace(self) - - def reset(self): - """Reset token tracking state between rollouts.""" - self.managed_server.reset() - - async def close(self): - """Compatibility method - no-op since ManagedServer handles cleanup.""" - pass - - def copy(self, **_kwargs) -> "AtroposManagedClient": - """ - Create a copy of this client (for API compatibility). - - Verifiers may call client.copy() for certain operations. - Returns self since we want to maintain the same ManagedServer state. - """ - return self diff --git a/atroposlib/tests/test_atropos_managed_client.py b/atroposlib/tests/test_atropos_managed_client.py deleted file mode 100644 index cb1cf955..00000000 --- a/atroposlib/tests/test_atropos_managed_client.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Tests for AtroposManagedClient - AsyncOpenAI-compatible wrapper for ManagedServer.""" - -import pytest - -from atroposlib.envs.server_handling.atropos_managed_client import ( - AtroposManagedClient, - ChoiceLogprobs, - EnhancedChatCompletion, - LogprobContent, -) -from atroposlib.envs.server_handling.managed_server import ManagedServer -from atroposlib.envs.server_handling.server_harness import ServerHarness - - -class MockTokenizer: - """Mock tokenizer for testing.""" - - def __init__(self): - self.eos_token_id = 2 - self.bos_token_id = 1 - - def encode(self, text, add_special_tokens=True): - """Simple character-based encoding for testing.""" - tokens = [ord(c) for c in text] - if add_special_tokens: - tokens = [self.bos_token_id] + tokens - return tokens - - def decode(self, tokens, skip_special_tokens=False): - """Simple character-based decoding for testing.""" - if skip_special_tokens: - tokens = [ - t for t in tokens if t not in [self.bos_token_id, self.eos_token_id] - ] - return "".join([chr(t) if t > 31 else "" for t in tokens]) - - def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): - """Simple chat template for testing.""" - result = "" - for msg in messages: - result += f"<{msg['role']}>{msg['content']}" - if add_generation_prompt: - result += "" - if tokenize: - return self.encode(result) - return result - - -@pytest.fixture -def mock_server(): - """Create a mock server with a tokenizer.""" - server = ServerHarness() - server.tokenizer = MockTokenizer() - - class Config: - model_name = "test_model" - - server.config = Config() - return server - - -@pytest.fixture -def managed_client(mock_server): - """Create an AtroposManagedClient with mocked server.""" - managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) - return AtroposManagedClient(managed_server=managed, model="test_model") - - -class TestDataclasses: - """Test the enhanced dataclasses.""" - - def test_logprob_content(self): - """Test LogprobContent creation.""" - lp = LogprobContent(logprob=-0.5, token="hello", token_id=100) - assert lp.logprob == -0.5 - assert lp.token == "hello" - assert lp.token_id == 100 - - def test_choice_logprobs(self): - """Test ChoiceLogprobs structure.""" - content = [ - LogprobContent(logprob=-0.1), - LogprobContent(logprob=-0.2), - ] - logprobs = ChoiceLogprobs(content=content) - assert len(logprobs.content) == 2 - assert 
logprobs.content[0].logprob == -0.1 - - -class TestAtroposManagedClient: - """Test AtroposManagedClient behavior.""" - - def test_reset(self, managed_client): - """Test reset clears ManagedServer state.""" - # Add some state to managed server - managed_client.managed_server.current_nodes = ["dummy"] - - # Reset should clear it - managed_client.reset() - assert len(managed_client.managed_server.current_nodes) == 0 - - def test_copy_returns_self(self, managed_client): - """Test copy returns same instance for state sharing.""" - copied = managed_client.copy() - assert copied is managed_client - - def test_namespace_structure(self, managed_client): - """Test client has correct namespace structure like AsyncOpenAI.""" - assert hasattr(managed_client, "chat") - assert hasattr(managed_client.chat, "completions") - assert hasattr(managed_client.chat.completions, "create") - - @pytest.mark.asyncio - async def test_close_is_noop(self, managed_client): - """Test close() doesn't raise.""" - await managed_client.close() # Should not raise - - -class TestChatCompletionCreate: - """Test the chat.completions.create() method.""" - - @pytest.mark.asyncio - async def test_basic_completion(self, mock_server, managed_client): - """Test basic chat completion returns enhanced response.""" - messages = [{"role": "user", "content": "Hello"}] - managed = managed_client.managed_server - prompt = managed._convert_messages_to_prompt(messages) - prompt_tokens = mock_server.tokenizer.encode(prompt) - - output_text = "Hi there!" - output_tokens = [ord(c) for c in output_text] - output_logprobs = [-0.1] * len(output_tokens) - - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - result = await managed_client.chat.completions.create( - messages=messages, - max_tokens=100, - ) - - # Should return EnhancedChatCompletion - assert isinstance(result, EnhancedChatCompletion) - assert len(result.choices) == 1 - assert result.choices[0].message.content == output_text - - # Should have prompt_token_ids - assert len(result.prompt_token_ids) == len(prompt_tokens) - - # Should have token_ids on choice - assert len(result.choices[0].token_ids) == len(output_tokens) - assert result.choices[0].token_ids == output_tokens - - # Should have logprobs - assert len(result.choices[0].logprobs.content) == len(output_tokens) - assert result.choices[0].logprobs.content[0].logprob == -0.1 - - @pytest.mark.asyncio - async def test_max_completion_tokens_param(self, mock_server, managed_client): - """Test max_completion_tokens is preferred over max_tokens.""" - messages = [{"role": "user", "content": "Hi"}] - managed = managed_client.managed_server - prompt = managed._convert_messages_to_prompt(messages) - prompt_tokens = mock_server.tokenizer.encode(prompt) - - output_tokens = [ord("!")] - output_logprobs = [-0.1] - - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - # Should accept max_completion_tokens (new OpenAI param) - result = await managed_client.chat.completions.create( - messages=messages, - max_completion_tokens=50, - ) - - assert isinstance(result, EnhancedChatCompletion) - - @pytest.mark.asyncio - async def test_reset_between_rollouts(self, mock_server, managed_client): - """Test that reset clears state between rollouts.""" - 
messages = [{"role": "user", "content": "Hello"}] - managed = managed_client.managed_server - prompt = managed._convert_messages_to_prompt(messages) - prompt_tokens = mock_server.tokenizer.encode(prompt) - - output_tokens = [ord("!")] - output_logprobs = [-0.1] - - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - # First rollout - await managed_client.chat.completions.create(messages=messages, max_tokens=10) - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 1 - - # Reset - managed_client.reset() - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 0 - - # Setup for second rollout - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - # Second rollout - await managed_client.chat.completions.create(messages=messages, max_tokens=10) - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 1 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From c56af35eaafa5376c5bb308b549bef1675d8b7bf Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Thu, 15 Jan 2026 11:34:40 +0530 Subject: [PATCH 20/22] switch to evalbase for verifiers_eval.py --- environments/README.md | 65 ++-- .../eval_environments/verifiers_eval.py | 362 +++++++++--------- 2 files changed, 216 insertions(+), 211 deletions(-) diff --git a/environments/README.md b/environments/README.md index 6000bf58..5473460b 100644 --- a/environments/README.md +++ b/environments/README.md @@ -64,16 +64,11 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein | `server/server_0_request_time_*` | API latency metrics (avg, std, 99p) | | `eval/avg_total_score` | Average score on evaluation dataset | -**W&B Metrics Logged (Evaluation - `verifiers_eval.py`):** +**Output (Evaluation - `verifiers_eval.py`):** -| Metric | Description | -|--------|-------------| -| `verifiers/accuracy` | Proportion of items with score > 0 | -| `verifiers/avg_score` | Average weighted score across all items | -| `verifiers/total_evaluated` | Number of successfully evaluated items | -| `verifiers/total_correct` | Number of items with score > 0 | -| `verifiers/reward_func_N_avg` | Per-reward function average score | -| `verifiers/reward_func_N_correct` | Per-reward function correct count | +Uses `evaluate_log()` from `atroposlib.envs.eval` to output: +- Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown +- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified) **Configuration Options (`VfEnvConfig` for `verifiers_server.py`):** @@ -82,18 +77,19 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein | `vf_env_name` | str | `""` | Prime environment identifier (e.g., `"will/wordle"`, `"primeintellect/gsm8k"`) | | `env_args` | Dict | `{}` | Additional arguments passed to `vf.load_environment()`. Read environment specific documentation to get these args. 
| -**Configuration Options (`VerifiersEvaluationConfig` for `verifiers_eval.py`):** +**CLI Options (`verifiers_eval.py`):** | Option | Type | Default | Description | |--------|------|---------|-------------| -| `vf_env_name` | str | `""` | Prime environment identifier | -| `env_args` | dict | `{}` | Additional arguments for verifiers environment | -| `temperature` | float | `0.0` | Temperature for generation (0.0 for deterministic) | -| `max_retries` | int | `3` | Maximum retries for failed API calls | -| `retry_delay` | float | `1.0` | Delay between retries in seconds | -| `min_response_length` | int | `1` | Minimum response length to consider valid | -| `full_debug` | bool | `False` | Enable verbose per-item debug output | -| `max_eval_items` | int | `-1` | Maximum number of items to evaluate (-1 for all) | +| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server | +| `--model-name` | str | (required) | Model name to evaluate | +| `--api-key` | str | `$OPENAI_API_KEY` | API key (defaults to env var) | +| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier | +| `--temperature` | float | `0.0` | Temperature for generation | +| `--max-tokens` | int | `2048` | Maximum tokens per completion | +| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) | +| `--max-concurrent` | int | `64` | Maximum concurrent requests | +| `--eval-dir` | str | `None` | Directory to save evaluation results | **Usage Examples:** @@ -124,31 +120,32 @@ python verifiers_server.py evaluate \ --env.vf_env_name "will/wordle" \ --openai.base_url http://localhost:9001/v1 -# Standalone Evaluation with detailed metrics (verifiers_eval.py) -python eval_environments/verifiers_eval.py evaluate \ - --env.vf_env_name "primeintellect/gsm8k" \ - --openai.model_name gpt-4o \ - --openai.api_key $OPENAI_API_KEY +# Standalone Evaluation with OpenAI (verifiers_eval.py) +python eval_environments/verifiers_eval.py \ + --server-url https://api.openai.com/v1 \ + --model-name gpt-4o \ + --vf-env-name primeintellect/gsm8k # Quick test run with limited items -python eval_environments/verifiers_eval.py evaluate \ - --env.vf_env_name "primeintellect/gsm8k" \ - --env.max_eval_items 10 \ - --openai.model_name gpt-4o \ - --openai.api_key $OPENAI_API_KEY +python eval_environments/verifiers_eval.py \ + --server-url https://api.openai.com/v1 \ + --model-name gpt-4o-mini \ + --vf-env-name primeintellect/alphabet-sort \ + --max-eval-items 10 -# Evaluation with debug output -python eval_environments/verifiers_eval.py evaluate \ - --env.vf_env_name "primeintellect/gsm8k" \ - --env.full_debug true \ - --openai.base_url http://localhost:9001/v1 +# Evaluation with local server and results saved +python eval_environments/verifiers_eval.py \ + --server-url http://localhost:9001/v1 \ + --model-name Qwen/Qwen2.5-7B-Instruct \ + --vf-env-name primeintellect/gsm8k \ + --eval-dir ./eval_results ``` **Key Implementation Details:** - **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`. - **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs. -- **Evaluation (`evaluate`)**: Runs on the environment's eval dataset with greedy decoding (temperature=0). 
The standalone `verifiers_eval.py` provides more detailed metrics and retry logic for production evaluation. +- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` pattern. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. **Prime Environment Installation:** ```bash diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index d6d8cf98..dd21e492 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -14,32 +14,25 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install Usage: - python verifiers_eval.py evaluate \ - --env.vf_env_name primeintellect/gsm8k \ - --openai.model_name gpt-4.1-nano \ - --openai.api_key $OPENAI_API_KEY + python verifiers_eval.py \ + --server-url http://localhost:8000/v1 \ + --model-name Qwen/Qwen2.5-7B-Instruct \ + --vf-env-name primeintellect/gsm8k \ + --max-eval-items 100 -Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.): - python verifiers_eval.py evaluate \ - --env.vf_env_name primeintellect/gsm8k \ - --openai.model_name Qwen/Qwen2.5-7B-Instruct \ - --openai.base_url http://localhost:8000/v1 +Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ -import os +import argparse +import asyncio import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple import verifiers as vf -import wandb from openai import AsyncOpenAI -from pydantic import Field -from atroposlib.envs.base import ( - APIServerConfig, - BaseEnv, - BaseEnvConfig, -) +from atroposlib.envs.eval import EvalBase, evaluate_log +from atroposlib.envs.server_handling.server_manager import ServerManager # Patch math_verify timeout to work in async context @@ -70,149 +63,113 @@ except ImportError: pass # math_verify not installed -class VerifiersEvaluationConfig(BaseEnvConfig): - """Configuration for Verifiers evaluation environment.""" - - vf_env_name: str = Field( - default="", - description="Verifiers environment name (e.g., primeintellect/gsm8k)", - ) - env_args: Dict[str, Any] = Field( - default_factory=dict, - description="Additional arguments for verifiers environment", - ) - temperature: float = Field( - default=0.0, - description="Temperature for generation (0.0 for deterministic)", - ) - max_eval_items: int = Field( - default=-1, - description="Maximum number of items to evaluate (-1 for all)", - ) - max_concurrent: int = Field( - default=64, - description="Maximum concurrent requests to the model", - ) - - # Override BaseEnvConfig defaults for evaluation - group_size: int = 1 - max_num_workers: int = 1024 - max_eval_workers: int = 256 - max_num_workers_per_node: int = 128 - use_wandb: bool = True - rollout_server_url: str = "http://localhost:8000" - total_steps: int = 1 - steps_per_eval: int = 1 - wandb_name: str = "verifiers_eval" - - -class VerifiersEvaluationEnv(BaseEnv): +class VerifiersEval(EvalBase): """ - Verifiers Evaluation Environment. + Verifiers Evaluation using EvalBase pattern. - Evaluates models using Prime Intellect's Verifiers library. - Uses verifiers' native rollout and scoring machinery. + Uses verifiers' native batch evaluation for efficiency, + with EvalBase's standardized logging via evaluate_log(). Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) 
""" - name = "verifiers_evaluation" - env_config_cls = VerifiersEvaluationConfig # type: ignore[assignment] - def __init__( self, - config: VerifiersEvaluationConfig, - server_configs: List[APIServerConfig], - slurm: bool = False, - testing: bool = False, + vf_env_name: str = "primeintellect/gsm8k", + env_args: dict = None, + temperature: float = 0.0, + max_tokens: int = 2048, + max_eval_items: int = -1, + max_concurrent: int = 64, + eval_dir: str = None, + verbose: bool = True, + **kwargs, ): - super().__init__(config, server_configs, slurm, testing) - self.config: VerifiersEvaluationConfig = config + self.vf_env_name = vf_env_name + self.env_args = env_args or {} + self.temperature = temperature + self.max_tokens = max_tokens + self.max_eval_items = max_eval_items + self.max_concurrent = max_concurrent - self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) - - # Get reward function names for metrics reporting + # Load verifiers environment + self.vf_env = vf.load_environment(vf_env_name, **self.env_args) self.reward_func_names = self.vf_env.rubric._get_reward_func_names() - @classmethod - def config_init(cls) -> Tuple[VerifiersEvaluationConfig, List[APIServerConfig]]: - """Default configuration for evaluation.""" - env_config = VerifiersEvaluationConfig( - vf_env_name="primeintellect/gsm8k", - ) - server_configs = [ - APIServerConfig( - model_name="gpt-4.1-nano", - base_url="https://api.openai.com/v1", - api_key=os.getenv("OPENAI_API_KEY"), - ), - ] - return env_config, server_configs - - def _get_openai_client(self) -> AsyncOpenAI: - """Create AsyncOpenAI client from first server config.""" - server = self.server.servers[0] - config = server.config - return AsyncOpenAI( - api_key=config.api_key or "x", - base_url=config.base_url, - timeout=config.timeout, + # Initialize EvalBase (calls setup_data) + super().__init__( + eval_dir=eval_dir, + verbose=verbose, + **kwargs, ) - def _get_model_name(self) -> str: - """Get model name from first server config.""" - return self.server.servers[0].config.model_name + def get_generation_params(self): + """Generation params for logging.""" + return { + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": 1, + } - async def setup(self) -> None: - """Initialize the environment.""" - num_eval = len(self.vf_env.get_eval_dataset()) - if self.config.max_eval_items > 0: - num_eval = min(num_eval, self.config.max_eval_items) + def setup_data(self) -> list: + """Return evaluation dataset from verifiers environment.""" + dataset = self.vf_env.get_eval_dataset() + if self.max_eval_items > 0: + n = min(len(dataset), self.max_eval_items) + dataset = dataset.select(range(n)) + return dataset.to_list() - print("\nVerifiers Evaluation Setup:") - print(f" Environment: {self.config.vf_env_name}") - print(f" Reward functions: {self.reward_func_names}") - print(f" Evaluation items: {num_eval}") - print(f" Max concurrent: {self.config.max_concurrent}") - - async def evaluate(self) -> Dict: - """Run evaluation using verifiers' native machinery.""" - num_examples = ( - self.config.max_eval_items if self.config.max_eval_items > 0 else -1 + async def run_item( + self, server: ServerManager, data_item: dict # noqa: ARG002 + ) -> Tuple[dict, list]: + """Not used - we override __call__ for batch evaluation.""" + raise NotImplementedError( + "VerifiersEval uses batch evaluation via __call__, not per-item run_item" ) - print(f"\n{'=' * 60}") - print(f"Starting Verifiers Evaluation: {self.config.vf_env_name}") - print(f"{'=' * 60}") - 
print(f" Model: {self._get_model_name()}") - print(f" Temperature: {self.config.temperature}") - print(f" Max concurrent: {self.config.max_concurrent}") - print(f"{'=' * 60}\n") - + async def __call__(self, server_manager: ServerManager): + """Run evaluation using verifiers' native batch machinery.""" start_time = time.time() - # Create OpenAI client from atropos server config - client = self._get_openai_client() - model = self._get_model_name() + # Create OpenAI client from server config + server = server_manager.servers[0] + client = AsyncOpenAI( + api_key=server.config.api_key or "x", + base_url=server.config.base_url, + timeout=getattr(server.config, "timeout", 600), + ) + model = server.config.model_name - # Let verifiers handle everything: rollouts + scoring + print(f"\n{'=' * 60}") + print(f"Verifiers Evaluation: {self.vf_env_name}") + print(f"{'=' * 60}") + print(f" Model: {model}") + print(f" Items: {len(self.data)}") + print(f" Reward functions: {self.reward_func_names}") + print(f" Temperature: {self.temperature}") + print(f" Max concurrent: {self.max_concurrent}") + print(f"{'=' * 60}\n") + + num_examples = self.max_eval_items if self.max_eval_items > 0 else -1 + + # Use verifiers' batch evaluation results = await self.vf_env.evaluate( client=client, model=model, sampling_args={ - "temperature": self.config.temperature, - "max_tokens": self.config.max_token_length, + "temperature": self.temperature, + "max_tokens": self.max_tokens, }, num_examples=num_examples, - max_concurrent=self.config.max_concurrent, + max_concurrent=self.max_concurrent, save_results=False, ) end_time = time.time() - # Extract metrics from verifiers output + # Extract from verifiers output rewards = results["reward"] - per_func_metrics = results["metrics"] # dict of func_name -> list[float] + per_func_metrics = results["metrics"] prompts = results["prompt"] completions = results["completion"] answers = results["answer"] @@ -232,14 +189,11 @@ class VerifiersEvaluationEnv(BaseEnv): } metrics = { - "avg_score": avg_score, "accuracy": accuracy, - "total_evaluated": total, - "total_correct": correct, - "reward_breakdown": reward_breakdown, + "avg_score": avg_score, } - # Print results + # Print results summary print(f"\n{'=' * 60}") print("Verifiers Evaluation Results") print(f"{'=' * 60}") @@ -253,14 +207,13 @@ class VerifiersEvaluationEnv(BaseEnv): ) print(f"{'=' * 60}\n") - # Log to evaluate_log (atropos's logging system) + # Build samples for logging system_prompt = self.vf_env.system_prompt or "" samples = [] for i in range(min(total, 100)): # Limit samples for logging prompt_msgs = prompts[i] if isinstance(prompts[i], list) else [] completion_msgs = completions[i] if completions[i] else [] - # Build full message list messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -280,59 +233,114 @@ class VerifiersEvaluationEnv(BaseEnv): } ) - await self.evaluate_log( - metrics={"accuracy": accuracy, "avg_score": avg_score}, - samples=samples, + # Use EvalBase's evaluate_log + task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}" + evaluate_log( + metrics=metrics, + eval_dir=getattr(self, "eval_dir", None), + task_name=task_name, + model_name=model, start_time=start_time, end_time=end_time, - generation_parameters={ - "temperature": self.config.temperature, - "max_tokens": self.config.max_token_length, - }, + generation_parameters=self.get_generation_params(), + samples=samples, + verbose=getattr(self, "verbose", False), ) - # Log to wandb - await 
self.wandb_log(metrics) - return metrics - async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: - """Log metrics to Weights & Biases.""" - if not self.config.use_wandb or wandb_metrics is None: - return - # Lazy init if wandb not already initialized - if wandb.run is None: - wandb.init( - project="atropos-environments", - name=self.config.wandb_name, - config=self.config.model_dump(), - ) +async def main(): + """CLI entry point for verifiers evaluation.""" + import os - log_dict = { - "verifiers/accuracy": wandb_metrics.get("accuracy", 0), - "verifiers/avg_score": wandb_metrics.get("avg_score", 0), - "verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0), - "verifiers/total_correct": wandb_metrics.get("total_correct", 0), - } + from atroposlib.envs.server_handling.server_baseline import APIServerConfig - # Add per-reward function metrics - reward_breakdown = wandb_metrics.get("reward_breakdown", {}) - for func_name, data in reward_breakdown.items(): - log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0) - log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0) + parser = argparse.ArgumentParser( + description="Evaluate models using Verifiers environments" + ) + # Server args (same as eval_runner) + parser.add_argument( + "--server-url", + type=str, + default="http://localhost:8000/v1", + help="URL of the inference server", + ) + parser.add_argument( + "--model-name", + type=str, + required=True, + help="Model name to evaluate", + ) + parser.add_argument( + "--api-key", + type=str, + default=os.getenv("OPENAI_API_KEY", "x"), + help="API key (defaults to OPENAI_API_KEY env var)", + ) + # Verifiers-specific args + parser.add_argument( + "--vf-env-name", + type=str, + default="primeintellect/gsm8k", + help="Verifiers environment name (e.g., primeintellect/gsm8k)", + ) + parser.add_argument( + "--max-eval-items", + type=int, + default=-1, + help="Maximum items to evaluate (-1 for all)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Generation temperature", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=2048, + help="Maximum tokens per completion", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=64, + help="Maximum concurrent requests", + ) + parser.add_argument( + "--eval-dir", + type=str, + default=None, + help="Directory to save evaluation results", + ) + args = parser.parse_args() - wandb.log(log_dict) + # Create server manager + server_manager = ServerManager( + configs=[ + APIServerConfig( + api_key=args.api_key, + base_url=args.server_url, + model_name=args.model_name, + health_check=False, + ), + ] + ) - # Required abstract method implementations (stubs for evaluation-only mode) - async def get_next_item(self) -> Optional[Dict]: - """Not used in evaluation mode.""" - return None - - async def collect_trajectories(self, item) -> Tuple[List, List]: # noqa: ARG002 - """Not used in evaluation mode.""" - return [], [] + # Create and run evaluation + eval_instance = VerifiersEval( + vf_env_name=args.vf_env_name, + max_eval_items=args.max_eval_items, + temperature=args.temperature, + max_tokens=args.max_tokens, + max_concurrent=args.max_concurrent, + eval_dir=args.eval_dir, + verbose=True, + ) + return await eval_instance(server_manager) if __name__ == "__main__": - VerifiersEvaluationEnv.cli() + asyncio.run(main()) From 5a20abdce7089caa19e4d3aae633ccc2fe6f1eea Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 23 Jan 2026 23:25:19 +0530 Subject: 
[PATCH 21/22] switch eval to use managed server adapter impl. moved managed server adapter --- .../envs/server_handling/managed_server.py | 60 +++ environments/README.md | 56 +-- .../eval_environments/verifiers_eval.py | 351 ++++++++---------- environments/verifiers_server.py | 56 +-- 4 files changed, 253 insertions(+), 270 deletions(-) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index cbc97268..7b26b5c1 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -509,3 +509,63 @@ class ManagedServer: self.sequences.clear() else: self.current_nodes.clear() + + +class ManagedServerAdapter: + """ + Adapter that makes ManagedServer look like AsyncOpenAI for external libraries. + + Implements the subset of AsyncOpenAI interface commonly used: + - client.chat.completions.create() + - client.completions.create() + - client.base_url + + This allows libraries like verifiers to use ManagedServer transparently + while still getting automatic token and logprob tracking. + """ + + def __init__(self, managed_server: ManagedServer, base_url: str): + """ + Initialize the adapter. + + Args: + managed_server: The ManagedServer instance to wrap + base_url: The base URL to expose (for compatibility checks) + """ + self._managed = managed_server + self.base_url = base_url + self.chat = self._ChatNamespace(self._managed) + self.completions = self._CompletionsNamespace(self._managed) + + class _ChatNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed) + + class _ChatCompletionsNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + + async def create(self, **kwargs): + return await self._managed.chat_completion(**kwargs) + + class _CompletionsNamespace: + def __init__(self, managed: ManagedServer): + self._managed = managed + + async def create(self, **kwargs): + return await self._managed.completion(**kwargs) + + async def post(self, path: str, body: dict, cast_to: type): + """Not supported - raises NotImplementedError.""" + raise NotImplementedError( + f"ManagedServerAdapter does not support post() for path '{path}'. " + "This is used for vLLM interleaved rollouts. Use standard chat completions." + ) + + def copy(self, **kwargs): + """Not supported - raises NotImplementedError.""" + raise NotImplementedError( + "ManagedServerAdapter does not support copy(). " + "This is used for vLLM tokenization endpoints." + ) diff --git a/environments/README.md b/environments/README.md index 5473460b..67bcb6f6 100644 --- a/environments/README.md +++ b/environments/README.md @@ -66,9 +66,9 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein **Output (Evaluation - `verifiers_eval.py`):** -Uses `evaluate_log()` from `atroposlib.envs.eval` to output: +Uses `evaluate_log()` from `BaseEnv` to output: - Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown -- File: `metrics.json` and `samples.jsonl` (when `--eval-dir` is specified) +- File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified) **Configuration Options (`VfEnvConfig` for `verifiers_server.py`):** @@ -79,17 +79,19 @@ Uses `evaluate_log()` from `atroposlib.envs.eval` to output: **CLI Options (`verifiers_eval.py`):** +Uses the standard BaseEnv CLI pattern with `evaluate` subcommand. 
Key options: + | Option | Type | Default | Description | |--------|------|---------|-------------| -| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server | -| `--model-name` | str | (required) | Model name to evaluate | -| `--api-key` | str | `$OPENAI_API_KEY` | API key (defaults to env var) | -| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier | -| `--temperature` | float | `0.0` | Temperature for generation | -| `--max-tokens` | int | `2048` | Maximum tokens per completion | -| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) | -| `--max-concurrent` | int | `64` | Maximum concurrent requests | -| `--eval-dir` | str | `None` | Directory to save evaluation results | +| `--openai.base_url` | str | `http://localhost:9001/v1` | URL of the inference server | +| `--openai.model_name` | str | `Qwen/Qwen2.5-1.5B-Instruct` | Model name to evaluate | +| `--openai.api_key` | str | `x` | API key | +| `--env.vf_env_name` | str | `primeintellect/gsm8k` | Prime environment identifier | +| `--env.eval_temperature` | float | `0.0` | Temperature for generation | +| `--env.eval_max_tokens` | int | `2048` | Maximum tokens per completion | +| `--env.max_eval_items` | int | `-1` | Maximum items to evaluate (-1 for all) | +| `--env.max_concurrent` | int | `64` | Maximum concurrent requests | +| `--env.data_dir_to_save_evals` | str | `None` | Directory to save evaluation results | **Usage Examples:** @@ -121,31 +123,33 @@ python verifiers_server.py evaluate \ --openai.base_url http://localhost:9001/v1 # Standalone Evaluation with OpenAI (verifiers_eval.py) -python eval_environments/verifiers_eval.py \ - --server-url https://api.openai.com/v1 \ - --model-name gpt-4o \ - --vf-env-name primeintellect/gsm8k +python eval_environments/verifiers_eval.py evaluate \ + --openai.base_url https://api.openai.com/v1 \ + --openai.api_key $OPENAI_API_KEY \ + --openai.model_name gpt-4o \ + --env.vf_env_name primeintellect/gsm8k # Quick test run with limited items -python eval_environments/verifiers_eval.py \ - --server-url https://api.openai.com/v1 \ - --model-name gpt-4o-mini \ - --vf-env-name primeintellect/alphabet-sort \ - --max-eval-items 10 +python eval_environments/verifiers_eval.py evaluate \ + --openai.base_url https://api.openai.com/v1 \ + --openai.api_key $OPENAI_API_KEY \ + --openai.model_name gpt-4o-mini \ + --env.vf_env_name primeintellect/alphabet-sort \ + --env.max_eval_items 10 # Evaluation with local server and results saved -python eval_environments/verifiers_eval.py \ - --server-url http://localhost:9001/v1 \ - --model-name Qwen/Qwen2.5-7B-Instruct \ - --vf-env-name primeintellect/gsm8k \ - --eval-dir ./eval_results +python eval_environments/verifiers_eval.py evaluate \ + --openai.base_url http://localhost:9001/v1 \ + --openai.model_name Qwen/Qwen2.5-7B-Instruct \ + --env.vf_env_name primeintellect/gsm8k \ + --env.data_dir_to_save_evals ./eval_results ``` **Key Implementation Details:** - **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`. - **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs. -- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` pattern. 
Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. +- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `BaseEnv` pattern with `evaluate` subcommand. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. **Prime Environment Installation:** ```bash diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index dd21e492..3100cf04 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -14,31 +14,53 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install Usage: - python verifiers_eval.py \ - --server-url http://localhost:8000/v1 \ - --model-name Qwen/Qwen2.5-7B-Instruct \ - --vf-env-name primeintellect/gsm8k \ - --max-eval-items 100 + # Evaluate with local server + python verifiers_eval.py evaluate \ + --env.vf_env_name "primeintellect/gsm8k" \ + --env.max_eval_items 100 \ + --openai.model_name "Qwen/Qwen2.5-7B-Instruct" \ + --openai.base_url "http://localhost:8000/v1" + + # Evaluate with OpenAI + python verifiers_eval.py evaluate \ + --env.vf_env_name "primeintellect/gsm8k" \ + --env.max_eval_items 50 \ + --openai.model_name "gpt-4o" \ + --openai.api_key "$OPENAI_API_KEY" \ + --openai.base_url "https://api.openai.com/v1" Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ -import argparse -import asyncio +import json import time -from typing import Tuple +from typing import Any, Dict, List, Tuple import verifiers as vf -from openai import AsyncOpenAI -from atroposlib.envs.eval import EvalBase, evaluate_log -from atroposlib.envs.server_handling.server_manager import ServerManager +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) +# Import ManagedServerAdapter from shared location +from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter # Patch math_verify timeout to work in async context # The signal-based timeout doesn't work in non-main threads (asyncio event loop) -def _no_signal_timeout(_timeout_seconds: int): - """Replacement timeout decorator that doesn't use signals.""" + + +def _no_signal_timeout( + _timeout_seconds: int | None = None, *, timeout_seconds: int | None = None +): + """Replacement timeout decorator that doesn't use signals. + + Accepts both positional arg (timeout(5)) and keyword arg (timeout(timeout_seconds=5)). 
+ """ + # Silence unused parameter warnings - these match the original API signature + del _timeout_seconds, timeout_seconds def decorator(func): def wrapper(*args, **kwargs): @@ -63,108 +85,143 @@ except ImportError: pass # math_verify not installed -class VerifiersEval(EvalBase): +class VfEvalConfig(BaseEnvConfig): + """Configuration for Verifiers evaluation environment.""" + + vf_env_name: str = "primeintellect/gsm8k" + env_args: str = "{}" # JSON string for environment-specific args + eval_temperature: float = 0.0 + eval_max_tokens: int = 2048 + max_eval_items: int = -1 # -1 means evaluate all items + max_concurrent: int = 64 + + # Override BaseEnvConfig defaults for eval mode + group_size: int = 1 + total_steps: int = 1 + steps_per_eval: int = 1 + use_wandb: bool = True + + def get_env_args(self) -> Dict[str, Any]: + """Parse env_args JSON string into dict.""" + if isinstance(self.env_args, dict): + return self.env_args + return json.loads(self.env_args) + + +class VerifiersEvalEnv(BaseEnv): """ - Verifiers Evaluation using EvalBase pattern. + Verifiers Evaluation Environment using BaseEnv pattern. Uses verifiers' native batch evaluation for efficiency, - with EvalBase's standardized logging via evaluate_log(). + with BaseEnv's standardized logging via evaluate_log(). Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ - def __init__( - self, - vf_env_name: str = "primeintellect/gsm8k", - env_args: dict = None, - temperature: float = 0.0, - max_tokens: int = 2048, - max_eval_items: int = -1, - max_concurrent: int = 64, - eval_dir: str = None, - verbose: bool = True, - **kwargs, - ): - self.vf_env_name = vf_env_name - self.env_args = env_args or {} - self.temperature = temperature - self.max_tokens = max_tokens - self.max_eval_items = max_eval_items - self.max_concurrent = max_concurrent + name = "verifiers_eval" + env_config_cls = VfEvalConfig # type: ignore[assignment] - # Load verifiers environment - self.vf_env = vf.load_environment(vf_env_name, **self.env_args) + @classmethod + def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]: + """Return default configurations.""" + env_config = VfEvalConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", + vf_env_name="primeintellect/gsm8k", + wandb_name="verifiers_eval", + ) + server_configs = [ + APIServerConfig( + model_name="Qwen/Qwen2.5-1.5B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + ), + ] + return env_config, server_configs + + async def setup(self): + """Load verifiers environment and dataset.""" + env_args = self.config.get_env_args() + self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args) self.reward_func_names = self.vf_env.rubric._get_reward_func_names() - # Initialize EvalBase (calls setup_data) - super().__init__( - eval_dir=eval_dir, - verbose=verbose, - **kwargs, - ) - - def get_generation_params(self): - """Generation params for logging.""" - return { - "temperature": self.temperature, - "max_tokens": self.max_tokens, - "n": 1, - } - - def setup_data(self) -> list: - """Return evaluation dataset from verifiers environment.""" + # Load evaluation dataset dataset = self.vf_env.get_eval_dataset() - if self.max_eval_items > 0: - n = min(len(dataset), self.max_eval_items) + if self.config.max_eval_items > 0: + n = min(len(dataset), self.config.max_eval_items) dataset = dataset.select(range(n)) - return dataset.to_list() + self.data = dataset.to_list() - async def run_item( - self, server: ServerManager, data_item: dict # noqa: ARG002 - ) -> 
Tuple[dict, list]: - """Not used - we override __call__ for batch evaluation.""" - raise NotImplementedError( - "VerifiersEval uses batch evaluation via __call__, not per-item run_item" + async def get_next_item(self): + """Not used in eval mode - stub implementation.""" + return None + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: + """Not used in eval mode - stub implementation.""" + _ = item # unused in eval mode + return ( + ScoredDataGroup( + tokens=[], + masks=[], + scores=[], + messages=[], + inference_logprobs=[], + advantages=[], + ref_logprobs=[], + generation_params=None, + group_overrides=None, + overrides=[], + images=[], + ), + [], ) - async def __call__(self, server_manager: ServerManager): - """Run evaluation using verifiers' native batch machinery.""" + async def evaluate(self, *args, **kwargs) -> Dict[str, float]: + """Run evaluation using verifiers with ManagedServer.""" start_time = time.time() - # Create OpenAI client from server config - server = server_manager.servers[0] - client = AsyncOpenAI( - api_key=server.config.api_key or "x", - base_url=server.config.base_url, - timeout=getattr(server.config, "timeout", 600), - ) - model = server.config.model_name + # Get server config + if hasattr(self.server, "servers") and self.server.servers: + server_config = self.server.servers[0].config + else: + server_config = self.server_configs[0] + + model_name = server_config.model_name print(f"\n{'=' * 60}") - print(f"Verifiers Evaluation: {self.vf_env_name}") + print(f"Verifiers Evaluation: {self.config.vf_env_name}") print(f"{'=' * 60}") - print(f" Model: {model}") + print(f" Model: {model_name}") print(f" Items: {len(self.data)}") print(f" Reward functions: {self.reward_func_names}") - print(f" Temperature: {self.temperature}") - print(f" Max concurrent: {self.max_concurrent}") + print(f" Temperature: {self.config.eval_temperature}") + print(f" Max concurrent: {self.config.max_concurrent}") print(f"{'=' * 60}\n") - num_examples = self.max_eval_items if self.max_eval_items > 0 else -1 - - # Use verifiers' batch evaluation - results = await self.vf_env.evaluate( - client=client, - model=model, - sampling_args={ - "temperature": self.temperature, - "max_tokens": self.max_tokens, - }, - num_examples=num_examples, - max_concurrent=self.max_concurrent, - save_results=False, + num_examples = ( + self.config.max_eval_items if self.config.max_eval_items > 0 else -1 ) + # Use ManagedServer for automatic token/logprob tracking + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Create adapter that looks like AsyncOpenAI for verifiers + adapter = ManagedServerAdapter( + managed_server=managed, + base_url=server_config.base_url, + ) + + # Use verifiers' batch evaluation + results = await self.vf_env.evaluate( + client=adapter, + model=model_name, + sampling_args={ + "temperature": self.config.eval_temperature, + "max_tokens": self.config.eval_max_tokens, + }, + num_examples=num_examples, + max_concurrent=self.config.max_concurrent, + save_results=False, + ) + end_time = time.time() # Extract from verifiers output @@ -193,6 +250,11 @@ class VerifiersEval(EvalBase): "avg_score": avg_score, } + # Add per-function metrics + for func_name, data in reward_breakdown.items(): + metrics[f"{func_name}_avg"] = data["avg"] + metrics[f"{func_name}_correct_rate"] = data["correct"] / total + # Print results summary print(f"\n{'=' * 60}") print("Verifiers Evaluation Results") @@ -233,114 +295,25 @@ class VerifiersEval(EvalBase): } ) - # Use 
EvalBase's evaluate_log - task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}" - evaluate_log( + # Use BaseEnv's evaluate_log + task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}" + await self.evaluate_log( metrics=metrics, - eval_dir=getattr(self, "eval_dir", None), task_name=task_name, - model_name=model, + model_name=model_name, start_time=start_time, end_time=end_time, - generation_parameters=self.get_generation_params(), + generation_parameters={ + "temperature": self.config.eval_temperature, + "max_tokens": self.config.eval_max_tokens, + "n": 1, + }, samples=samples, - verbose=getattr(self, "verbose", False), + verbose=True, ) return metrics -async def main(): - """CLI entry point for verifiers evaluation.""" - import os - - from atroposlib.envs.server_handling.server_baseline import APIServerConfig - - parser = argparse.ArgumentParser( - description="Evaluate models using Verifiers environments" - ) - # Server args (same as eval_runner) - parser.add_argument( - "--server-url", - type=str, - default="http://localhost:8000/v1", - help="URL of the inference server", - ) - parser.add_argument( - "--model-name", - type=str, - required=True, - help="Model name to evaluate", - ) - parser.add_argument( - "--api-key", - type=str, - default=os.getenv("OPENAI_API_KEY", "x"), - help="API key (defaults to OPENAI_API_KEY env var)", - ) - # Verifiers-specific args - parser.add_argument( - "--vf-env-name", - type=str, - default="primeintellect/gsm8k", - help="Verifiers environment name (e.g., primeintellect/gsm8k)", - ) - parser.add_argument( - "--max-eval-items", - type=int, - default=-1, - help="Maximum items to evaluate (-1 for all)", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.0, - help="Generation temperature", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=2048, - help="Maximum tokens per completion", - ) - parser.add_argument( - "--max-concurrent", - type=int, - default=64, - help="Maximum concurrent requests", - ) - parser.add_argument( - "--eval-dir", - type=str, - default=None, - help="Directory to save evaluation results", - ) - args = parser.parse_args() - - # Create server manager - server_manager = ServerManager( - configs=[ - APIServerConfig( - api_key=args.api_key, - base_url=args.server_url, - model_name=args.model_name, - health_check=False, - ), - ] - ) - - # Create and run evaluation - eval_instance = VerifiersEval( - vf_env_name=args.vf_env_name, - max_eval_items=args.max_eval_items, - temperature=args.temperature, - max_tokens=args.max_tokens, - max_concurrent=args.max_concurrent, - eval_dir=args.eval_dir, - verbose=True, - ) - return await eval_instance(server_manager) - - if __name__ == "__main__": - asyncio.run(main()) + VerifiersEvalEnv.cli() diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 0cd8cf45..98413d8a 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -46,65 +46,11 @@ from atroposlib.envs.base import ( BaseEnvConfig, ScoredDataGroup, ) -from atroposlib.envs.server_handling.managed_server import ManagedServer +from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter logger = logging.getLogger(__name__) -class ManagedServerAdapter: - """ - Adapter that makes ManagedServer look like AsyncOpenAI for verifiers. 
- - Implements the subset of AsyncOpenAI interface that verifiers uses: - - client.chat.completions.create() - - client.completions.create() - - client.base_url - """ - - def __init__(self, managed_server: ManagedServer, base_url: str): - self._managed = managed_server - self.base_url = base_url - self.chat = self._ChatNamespace(self._managed) - self.completions = self._CompletionsNamespace(self._managed) - - class _ChatNamespace: - def __init__(self, managed: ManagedServer): - self._managed = managed - self.completions = ManagedServerAdapter._ChatCompletionsNamespace(managed) - - class _ChatCompletionsNamespace: - def __init__(self, managed: ManagedServer): - self._managed = managed - - async def create(self, **kwargs): - logger.info( - "ManagedServerAdapter.chat.completions.create called with model=%s", - kwargs.get("model"), - ) - result = await self._managed.chat_completion(**kwargs) - logger.info("ManagedServerAdapter.chat.completions.create completed") - return result - - class _CompletionsNamespace: - def __init__(self, managed: ManagedServer): - self._managed = managed - - async def create(self, **kwargs): - return await self._managed.completion(**kwargs) - - async def post(self, path: str, body: dict, cast_to: type): - raise NotImplementedError( - f"ManagedServerAdapter does not support post() for path '{path}'. " - "This is used for vLLM interleaved rollouts. Use standard chat completions." - ) - - def copy(self, **kwargs): - raise NotImplementedError( - "ManagedServerAdapter does not support copy(). " - "This is used for vLLM tokenization endpoints." - ) - - class VfEnvConfig(BaseEnvConfig): vf_env_name: str = "" env_args: str = "{}" From 4ba69d3a808d461388ef163662914d6123ac5229 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Fri, 23 Jan 2026 23:41:32 +0530 Subject: [PATCH 22/22] revert to using evalbase --- environments/README.md | 54 ++-- .../eval_environments/verifiers_eval.py | 306 ++++++++++-------- 2 files changed, 204 insertions(+), 156 deletions(-) diff --git a/environments/README.md b/environments/README.md index 67bcb6f6..ee700dde 100644 --- a/environments/README.md +++ b/environments/README.md @@ -66,7 +66,7 @@ A flexible environment that integrates with the [Verifiers](https://docs.primein **Output (Evaluation - `verifiers_eval.py`):** -Uses `evaluate_log()` from `BaseEnv` to output: +Uses `evaluate_log()` from `EvalBase` to output: - Console: Metrics table with accuracy, avg_score, time, and per-reward function breakdown - File: `metrics.json` and `samples.jsonl` (when `--env.data_dir_to_save_evals` is specified) @@ -79,19 +79,19 @@ Uses `evaluate_log()` from `BaseEnv` to output: **CLI Options (`verifiers_eval.py`):** -Uses the standard BaseEnv CLI pattern with `evaluate` subcommand. 
Key options: +Uses a simple argparse CLI with direct arguments: | Option | Type | Default | Description | |--------|------|---------|-------------| -| `--openai.base_url` | str | `http://localhost:9001/v1` | URL of the inference server | -| `--openai.model_name` | str | `Qwen/Qwen2.5-1.5B-Instruct` | Model name to evaluate | -| `--openai.api_key` | str | `x` | API key | -| `--env.vf_env_name` | str | `primeintellect/gsm8k` | Prime environment identifier | -| `--env.eval_temperature` | float | `0.0` | Temperature for generation | -| `--env.eval_max_tokens` | int | `2048` | Maximum tokens per completion | -| `--env.max_eval_items` | int | `-1` | Maximum items to evaluate (-1 for all) | -| `--env.max_concurrent` | int | `64` | Maximum concurrent requests | -| `--env.data_dir_to_save_evals` | str | `None` | Directory to save evaluation results | +| `--server-url` | str | `http://localhost:8000/v1` | URL of the inference server | +| `--model-name` | str | (required) | Model name to evaluate | +| `--api-key` | str | `$OPENAI_API_KEY` | API key (uses env var if not specified) | +| `--vf-env-name` | str | `primeintellect/gsm8k` | Prime environment identifier | +| `--temperature` | float | `0.0` | Temperature for generation | +| `--max-tokens` | int | `2048` | Maximum tokens per completion | +| `--max-eval-items` | int | `-1` | Maximum items to evaluate (-1 for all) | +| `--max-concurrent` | int | `64` | Maximum concurrent requests | +| `--eval-dir` | str | `None` | Directory to save evaluation results | **Usage Examples:** @@ -123,33 +123,31 @@ python verifiers_server.py evaluate \ --openai.base_url http://localhost:9001/v1 # Standalone Evaluation with OpenAI (verifiers_eval.py) -python eval_environments/verifiers_eval.py evaluate \ - --openai.base_url https://api.openai.com/v1 \ - --openai.api_key $OPENAI_API_KEY \ - --openai.model_name gpt-4o \ - --env.vf_env_name primeintellect/gsm8k +python eval_environments/verifiers_eval.py \ + --server-url https://api.openai.com/v1 \ + --model-name gpt-4o \ + --vf-env-name primeintellect/gsm8k # Quick test run with limited items -python eval_environments/verifiers_eval.py evaluate \ - --openai.base_url https://api.openai.com/v1 \ - --openai.api_key $OPENAI_API_KEY \ - --openai.model_name gpt-4o-mini \ - --env.vf_env_name primeintellect/alphabet-sort \ - --env.max_eval_items 10 +python eval_environments/verifiers_eval.py \ + --server-url https://api.openai.com/v1 \ + --model-name gpt-4o-mini \ + --vf-env-name primeintellect/alphabet-sort \ + --max-eval-items 10 # Evaluation with local server and results saved -python eval_environments/verifiers_eval.py evaluate \ - --openai.base_url http://localhost:9001/v1 \ - --openai.model_name Qwen/Qwen2.5-7B-Instruct \ - --env.vf_env_name primeintellect/gsm8k \ - --env.data_dir_to_save_evals ./eval_results +python eval_environments/verifiers_eval.py \ + --server-url http://localhost:9001/v1 \ + --model-name Qwen/Qwen2.5-7B-Instruct \ + --vf-env-name primeintellect/gsm8k \ + --eval-dir ./eval_results ``` **Key Implementation Details:** - **RL Training Mode (`serve`)**: Uses `ManagedServer` for proper token/logprob alignment required by policy gradient methods (GRPO, PPO, REINFORCE). Returns `ScoredDataGroup` with `tokens`, `masks`, `scores`, and `inference_logprobs`. - **SFT Datagen Mode (`process`)**: Uses `tokenize_for_trainer` to tokenize API responses with your target model's tokenizer (e.g., GPT-4o responses tokenized for Qwen/Llama). Does NOT require logprobs. 
-- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `BaseEnv` pattern with `evaluate` subcommand. Uses verifiers' native batch evaluation for efficiency and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. +- **Evaluation (`verifiers_eval.py`)**: Standalone evaluation script using `EvalBase` with simple argparse CLI. Uses verifiers' native batch evaluation with `ManagedServerAdapter` for token/logprob tracking and outputs results via `evaluate_log()`. Works with any OpenAI-compatible API. **Prime Environment Installation:** ```bash diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 3100cf04..559b0fd0 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -14,39 +14,33 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install Usage: - # Evaluate with local server - python verifiers_eval.py evaluate \ - --env.vf_env_name "primeintellect/gsm8k" \ - --env.max_eval_items 100 \ - --openai.model_name "Qwen/Qwen2.5-7B-Instruct" \ - --openai.base_url "http://localhost:8000/v1" - # Evaluate with OpenAI - python verifiers_eval.py evaluate \ - --env.vf_env_name "primeintellect/gsm8k" \ - --env.max_eval_items 50 \ - --openai.model_name "gpt-4o" \ - --openai.api_key "$OPENAI_API_KEY" \ - --openai.base_url "https://api.openai.com/v1" + python verifiers_eval.py \ + --server-url https://api.openai.com/v1 \ + --model-name gpt-4o \ + --vf-env-name primeintellect/gsm8k \ + --max-eval-items 50 + + # Evaluate with local server + python verifiers_eval.py \ + --server-url http://localhost:8000/v1 \ + --model-name Qwen/Qwen2.5-7B-Instruct \ + --vf-env-name primeintellect/gsm8k Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ +import argparse import json import time -from typing import Any, Dict, List, Tuple +from typing import Any, Dict import verifiers as vf -from atroposlib.envs.base import ( - APIServerConfig, - BaseEnv, - BaseEnvConfig, - ScoredDataGroup, -) - -# Import ManagedServerAdapter from shared location +from atroposlib.envs.eval import EvalBase, evaluate_log from atroposlib.envs.server_handling.managed_server import ManagedServerAdapter +from atroposlib.envs.server_handling.server_baseline import APIServerConfig +from atroposlib.envs.server_handling.server_manager import ServerManager # Patch math_verify timeout to work in async context # The signal-based timeout doesn't work in non-main threads (asyncio event loop) @@ -85,128 +79,74 @@ except ImportError: pass # math_verify not installed -class VfEvalConfig(BaseEnvConfig): - """Configuration for Verifiers evaluation environment.""" - - vf_env_name: str = "primeintellect/gsm8k" - env_args: str = "{}" # JSON string for environment-specific args - eval_temperature: float = 0.0 - eval_max_tokens: int = 2048 - max_eval_items: int = -1 # -1 means evaluate all items - max_concurrent: int = 64 - - # Override BaseEnvConfig defaults for eval mode - group_size: int = 1 - total_steps: int = 1 - steps_per_eval: int = 1 - use_wandb: bool = True - - def get_env_args(self) -> Dict[str, Any]: - """Parse env_args JSON string into dict.""" - if isinstance(self.env_args, dict): - return self.env_args - return json.loads(self.env_args) - - -class VerifiersEvalEnv(BaseEnv): +class VerifiersEval(EvalBase): """ - Verifiers Evaluation Environment using BaseEnv pattern. 
+ Verifiers Evaluation using EvalBase pattern. Uses verifiers' native batch evaluation for efficiency, - with BaseEnv's standardized logging via evaluate_log(). + with ManagedServerAdapter for token/logprob tracking. Works with any OpenAI-compatible API (OpenAI, vLLM, SGLang, Ollama, etc.) """ - name = "verifiers_eval" - env_config_cls = VfEvalConfig # type: ignore[assignment] + def __init__( + self, + vf_env_name: str = "primeintellect/gsm8k", + env_args: str = "{}", + temperature: float = 0.0, + max_tokens: int = 2048, + max_eval_items: int = -1, + max_concurrent: int = 64, + **kwargs, + ): + self.vf_env_name = vf_env_name + self.env_args_str = env_args + self.temperature = temperature + self.max_tokens = max_tokens + self.max_eval_items = max_eval_items + self.max_concurrent = max_concurrent + super().__init__(**kwargs) - @classmethod - def config_init(cls) -> Tuple[VfEvalConfig, List[APIServerConfig]]: - """Return default configurations.""" - env_config = VfEvalConfig( - tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", - vf_env_name="primeintellect/gsm8k", - wandb_name="verifiers_eval", - ) - server_configs = [ - APIServerConfig( - model_name="Qwen/Qwen2.5-1.5B-Instruct", - base_url="http://localhost:9001/v1", - api_key="x", - ), - ] - return env_config, server_configs + def get_env_args(self) -> Dict[str, Any]: + """Parse env_args JSON string into dict.""" + if isinstance(self.env_args_str, dict): + return self.env_args_str + return json.loads(self.env_args_str) - async def setup(self): + def setup_data(self) -> list: """Load verifiers environment and dataset.""" - env_args = self.config.get_env_args() - self.vf_env = vf.load_environment(self.config.vf_env_name, **env_args) + env_args = self.get_env_args() + self.vf_env = vf.load_environment(self.vf_env_name, **env_args) self.reward_func_names = self.vf_env.rubric._get_reward_func_names() # Load evaluation dataset dataset = self.vf_env.get_eval_dataset() - if self.config.max_eval_items > 0: - n = min(len(dataset), self.config.max_eval_items) + if self.max_eval_items > 0: + n = min(len(dataset), self.max_eval_items) dataset = dataset.select(range(n)) - self.data = dataset.to_list() + return dataset.to_list() - async def get_next_item(self): - """Not used in eval mode - stub implementation.""" - return None + async def run_item(self, server: ServerManager, data_item: dict): + """Not used - verifiers uses batch evaluation in __call__.""" + # This won't be called since we override __call__ + raise NotImplementedError("VerifiersEval uses batch evaluation in __call__") - async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: - """Not used in eval mode - stub implementation.""" - _ = item # unused in eval mode - return ( - ScoredDataGroup( - tokens=[], - masks=[], - scores=[], - messages=[], - inference_logprobs=[], - advantages=[], - ref_logprobs=[], - generation_params=None, - group_overrides=None, - overrides=[], - images=[], - ), - [], - ) - - async def evaluate(self, *args, **kwargs) -> Dict[str, float]: - """Run evaluation using verifiers with ManagedServer.""" + async def __call__(self, server_manager: ServerManager): + """Run evaluation using verifiers with ManagedServerAdapter.""" start_time = time.time() # Get server config - if hasattr(self.server, "servers") and self.server.servers: - server_config = self.server.servers[0].config - else: - server_config = self.server_configs[0] + server = server_manager.servers[0] + model_name = server.config.model_name - model_name = server_config.model_name - - print(f"\n{'=' * 
60}") - print(f"Verifiers Evaluation: {self.config.vf_env_name}") - print(f"{'=' * 60}") - print(f" Model: {model_name}") - print(f" Items: {len(self.data)}") - print(f" Reward functions: {self.reward_func_names}") - print(f" Temperature: {self.config.eval_temperature}") - print(f" Max concurrent: {self.config.max_concurrent}") - print(f"{'=' * 60}\n") - - num_examples = ( - self.config.max_eval_items if self.config.max_eval_items > 0 else -1 - ) + num_examples = self.max_eval_items if self.max_eval_items > 0 else -1 # Use ManagedServer for automatic token/logprob tracking - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + async with server_manager.managed_server(tokenizer=None) as managed: # Create adapter that looks like AsyncOpenAI for verifiers adapter = ManagedServerAdapter( managed_server=managed, - base_url=server_config.base_url, + base_url=server.config.base_url, ) # Use verifiers' batch evaluation @@ -214,11 +154,11 @@ class VerifiersEvalEnv(BaseEnv): client=adapter, model=model_name, sampling_args={ - "temperature": self.config.eval_temperature, - "max_tokens": self.config.eval_max_tokens, + "temperature": self.temperature, + "max_tokens": self.max_tokens, }, num_examples=num_examples, - max_concurrent=self.config.max_concurrent, + max_concurrent=self.max_concurrent, save_results=False, ) @@ -295,25 +235,135 @@ class VerifiersEvalEnv(BaseEnv): } ) - # Use BaseEnv's evaluate_log - task_name = f"VerifiersEval@{self.config.vf_env_name.replace('/', '_')}" - await self.evaluate_log( + # Log results + task_name = f"VerifiersEval@{self.vf_env_name.replace('/', '_')}" + evaluate_log( metrics=metrics, + eval_dir=getattr(self, "eval_dir", None), task_name=task_name, model_name=model_name, start_time=start_time, end_time=end_time, generation_parameters={ - "temperature": self.config.eval_temperature, - "max_tokens": self.config.eval_max_tokens, + "temperature": self.temperature, + "max_tokens": self.max_tokens, "n": 1, }, samples=samples, - verbose=True, + verbose=getattr(self, "verbose", True), ) return metrics +async def main(): + """Run verifiers evaluation with argparse CLI.""" + import os + + parser = argparse.ArgumentParser( + description="Evaluate models using Prime Intellect's Verifiers library" + ) + parser.add_argument( + "--server-url", + type=str, + default="http://localhost:8000/v1", + help="URL of the inference server (default: http://localhost:8000/v1)", + ) + parser.add_argument( + "--model-name", + type=str, + required=True, + help="Model name to evaluate", + ) + parser.add_argument( + "--api-key", + type=str, + default=None, + help="API key (default: uses OPENAI_API_KEY env var)", + ) + parser.add_argument( + "--vf-env-name", + type=str, + default="primeintellect/gsm8k", + help="Verifiers environment name (default: primeintellect/gsm8k)", + ) + parser.add_argument( + "--env-args", + type=str, + default="{}", + help="JSON string of environment-specific args (default: {})", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Generation temperature (default: 0.0)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=2048, + help="Maximum tokens per completion (default: 2048)", + ) + parser.add_argument( + "--max-eval-items", + type=int, + default=-1, + help="Maximum items to evaluate, -1 for all (default: -1)", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=64, + help="Maximum concurrent requests (default: 64)", + ) + parser.add_argument( + "--eval-dir", + type=str, + 
default=None,
+        help="Directory to save evaluation results (default: None)",
+    )
+    parser.add_argument(
+        "--verbose",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Print verbose output; pass --no-verbose to disable (default: True)",
+    )
+
+    args = parser.parse_args()
+
+    # Get API key from args or environment
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "dummy")
+
+    # Create evaluation instance
+    eval_env = VerifiersEval(
+        vf_env_name=args.vf_env_name,
+        env_args=args.env_args,
+        temperature=args.temperature,
+        max_tokens=args.max_tokens,
+        max_eval_items=args.max_eval_items,
+        max_concurrent=args.max_concurrent,
+        eval_dir=args.eval_dir,
+        verbose=args.verbose,
+    )
+
+    # Create server manager
+    server_manager = ServerManager(
+        configs=[
+            APIServerConfig(
+                api_key=api_key,
+                base_url=args.server_url,
+                model_name=args.model_name,
+                health_check=False,
+            ),
+        ]
+    )
+
+    # Run evaluation
+    return await eval_env(server_manager)
+
+
 if __name__ == "__main__":
-    VerifiersEvalEnv.cli()
+    import asyncio
+
+    asyncio.run(main())
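
Not part of the patch, but for readers who want to drive the new entry point without the CLI: a minimal sketch that mirrors the wiring `main()` in `verifiers_eval.py` performs. The endpoint URL, model name, and output directory are placeholders, and the import path assumes the script is importable from `environments/eval_environments/`; adjust both to your setup.

```python
# Minimal sketch (illustrative only, not part of this patch): drive VerifiersEval
# directly, mirroring the wiring done by main() in verifiers_eval.py.
import asyncio
import os

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager

# Assumed import path; the file lives at environments/eval_environments/verifiers_eval.py
from verifiers_eval import VerifiersEval


async def run_eval():
    eval_env = VerifiersEval(
        vf_env_name="primeintellect/gsm8k",
        max_eval_items=10,          # small smoke test instead of the full eval set
        temperature=0.0,
        max_tokens=2048,
        eval_dir="./eval_results",  # placeholder; evaluate_log() writes results here
        verbose=True,
    )
    server_manager = ServerManager(
        configs=[
            APIServerConfig(
                api_key=os.environ.get("OPENAI_API_KEY", "dummy"),
                base_url="http://localhost:8000/v1",  # placeholder endpoint
                model_name="Qwen/Qwen2.5-7B-Instruct",
                health_check=False,
            ),
        ]
    )
    # __call__ runs verifiers' batch evaluation and returns the metrics dict
    return await eval_env(server_manager)


if __name__ == "__main__":
    print(asyncio.run(run_eval()))
```

Functionally this matches the `--max-eval-items 10` quick-test invocation shown in the README hunk above, just expressed as code instead of a command line.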