diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py
index 13feb6f8..3e70081e 100644
--- a/environments/verifiers_server.py
+++ b/environments/verifiers_server.py
@@ -8,13 +8,13 @@ Supports TWO modes:
 Usage:
     # RL Training (requires local vLLM/SGLang server)
     python verifiers_server.py serve \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
        --openai.base_url http://localhost:9001/v1 \
         --slurm false
 
     # SFT Data Generation with OpenAI GPT-4o
     python verifiers_server.py process \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
         --env.total_steps 100 \
         --env.group_size 4 \
@@ -23,30 +23,30 @@ Usage:
     # SFT Data Generation with local server
     python verifiers_server.py process \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --env.data_path_to_save_groups local_sft_data.jsonl \
         --openai.base_url http://localhost:9001/v1
 
     # Evaluation (uses ManagedServer by default, falls back to direct API in process mode)
     python verifiers_server.py evaluate \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --openai.base_url http://localhost:9001/v1
 
 To install a Verifiers/Prime environment:
     1. uv tool install prime
     2. prime login
-    3. prime env install will/wordle (or any owner/environment)
+    3. prime env install primeintellect/alphabet-sort (or any owner/environment)
 
 Docs: https://docs.primeintellect.ai/tutorials-environments/install
 """
 
 import asyncio
+import json
 import logging
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple
 
 import verifiers as vf
 from openai import AsyncOpenAI
-from pydantic import Field
 
 from atroposlib.envs.base import (
     APIServerConfig,
@@ -61,7 +61,13 @@ logger = logging.getLogger(__name__)
 
 class VfEnvConfig(BaseEnvConfig):
     vf_env_name: str = ""
-    env_args: Dict[str, Any] = Field(default_factory=dict)
+    env_args: str = "{}"
+
+    def get_env_args(self) -> Dict[str, Any]:
+        """Parse env_args JSON string into dict."""
+        if isinstance(self.env_args, dict):
+            return self.env_args
+        return json.loads(self.env_args)
 
 
 class VerifiersEnv(BaseEnv):
@@ -85,7 +91,10 @@ class VerifiersEnv(BaseEnv):
         self.groups_total: int = 0
 
         logger.info("Loading verifiers environment: %s", config.vf_env_name)
-        self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
+        env_args = config.get_env_args()
+        if env_args:
+            logger.info("Environment args: %s", env_args)
+        self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
 
         self.rubric = self.vf_env.rubric
         self.system_prompt = self.vf_env.system_prompt
@@ -93,6 +102,10 @@ class VerifiersEnv(BaseEnv):
         self.reward_func_names = self.rubric._get_reward_func_names()
         logger.info("Reward functions: %s", self.reward_func_names)
 
+        # Log multi-turn config if available
+        if hasattr(self.vf_env, "max_turns"):
+            logger.info("Max turns: %d", self.vf_env.max_turns)
+
     @classmethod
     def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
         env_config = VfEnvConfig(
@@ -172,10 +185,9 @@ class VerifiersEnv(BaseEnv):
         await super().wandb_log(wandb_metrics)
 
     async def setup(self):
+        # Dataset already has: prompt, answer, info, example_id, task
         train_data = self.vf_env.get_dataset()
-        columns_to_keep = ["question", "answer", "info"]
-        available_columns = [c for c in columns_to_keep if c in train_data.column_names]
-        self.train = train_data.select_columns(available_columns).to_list()
+        self.train = train_data.to_list()
         self.iter = 0
 
     def save_checkpoint(self, step, data=None):
@@ -193,13 +205,6 @@ class VerifiersEnv(BaseEnv):
         """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
         return {}
 
-    def _build_initial_messages(self, question: str) -> List[Dict[str, str]]:
-        messages: List[Dict[str, str]] = []
-        if self.system_prompt:
-            messages.append({"role": "system", "content": self.system_prompt})
-        messages.append({"role": "user", "content": question})
-        return messages
-
     async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
         is_process_mode = getattr(self, "process_mode", False)
         if is_process_mode:
@@ -217,16 +222,16 @@ class VerifiersEnv(BaseEnv):
                 timeout=server_config.timeout,
             )
 
-        initial_messages = self._build_initial_messages(item["question"])
+        # item already has prompt, answer, example_id, task, info from dataset
         inputs = [
             {
-                "prompt": initial_messages,
-                "answer": item["answer"],
-                "example_id": i,
-                "task": self.config.vf_env_name,
+                "prompt": item["prompt"],
+                "answer": item.get("answer", ""),
+                "example_id": item["example_id"],
+                "task": item.get("task", self.config.vf_env_name),
                 "info": item.get("info", {}),
             }
-            for i in range(self.config.group_size)
+            for _ in range(self.config.group_size)
         ]
 
         results = await self.vf_env.generate(
@@ -279,7 +284,9 @@ class VerifiersEnv(BaseEnv):
 
             # Capture metrics for wandb logging
             self.reward_buffer.append(reward)
-            self.num_turns_buffer.append(len(trajectory))
+            num_turns = len(trajectory)
+            self.num_turns_buffer.append(num_turns)
+            logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
 
             # Extract per-function metrics from verifiers state
             state_metrics = state.get("metrics", {})
@@ -288,6 +295,15 @@ class VerifiersEnv(BaseEnv):
                 if isinstance(metric_value, (int, float)):
                     self.metrics_buffer[metric_name].append(float(metric_value))
 
+        # Log group summary
+        turns = [len(s.get("trajectory", [])) for s in results["state"]]
+        logger.info(
+            "Group: %d rollouts, turns=%s, rewards=%s",
+            len(results["state"]),
+            turns,
+            [f"{s:.3f}" for s in scored_data["scores"]],
+        )
+
         # Track group-level identical scores for debugging
         self.groups_total += 1
         if len(set(scored_data["scores"])) == 1:
@@ -307,7 +323,6 @@ class VerifiersEnv(BaseEnv):
             AtroposManagedClient,
         )
 
-        initial_messages = self._build_initial_messages(item["question"])
         sampling_args = {
             "temperature": 1.0,
             "max_completion_tokens": self.config.max_token_length,
@@ -315,17 +330,17 @@ class VerifiersEnv(BaseEnv):
         score_sem = asyncio.Semaphore(self.config.group_size)
         model = self.server.servers[0].config.model_name
 
-        async def run_rollout(example_id: int) -> Dict[str, Any]:
+        async def run_rollout() -> Dict[str, Any]:
             """Run a single rollout and return full state for metrics extraction."""
             async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                 client = AtroposManagedClient(managed_server=managed, model=model)
 
                 state = await self.vf_env.rollout(
                     input={
-                        "prompt": initial_messages,
-                        "answer": item["answer"],
-                        "example_id": example_id,
-                        "task": self.config.vf_env_name,
+                        "prompt": item["prompt"],
+                        "answer": item.get("answer", ""),
+                        "example_id": item["example_id"],
+                        "task": item.get("task", self.config.vf_env_name),
                         "info": item.get("info", {}),
                     },
                     client=client,
@@ -336,7 +351,7 @@ class VerifiersEnv(BaseEnv):
             return state
 
         states = await asyncio.gather(
-            *[run_rollout(i) for i in range(self.config.group_size)]
+            *[run_rollout() for _ in range(self.config.group_size)]
        )
 
         scored_data = ScoredDataGroup()
@@ -356,7 +371,9 @@ class VerifiersEnv(BaseEnv):
             self.reward_buffer.append(reward)
 
             trajectory = state.get("trajectory", [])
-            self.num_turns_buffer.append(len(trajectory))
+            num_turns = len(trajectory)
+            self.num_turns_buffer.append(num_turns)
+            logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
 
             # Extract per-function metrics from verifiers state
             state_metrics = state.get("metrics", {})
@@ -365,6 +382,15 @@ class VerifiersEnv(BaseEnv):
                 if isinstance(metric_value, (int, float)):
                     self.metrics_buffer[metric_name].append(float(metric_value))
 
+        # Log group summary
+        turns = [len(s.get("trajectory", [])) for s in states]
+        logger.info(
+            "Group: %d rollouts, turns=%s, rewards=%s",
+            len(states),
+            turns,
+            [f"{s:.3f}" for s in scored_data["scores"]],
+        )
+
         # Track group-level identical scores for debugging
         self.groups_total += 1
         if len(set(scored_data["scores"])) == 1:
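
Note on the env_args change: typed CLI layers generally hand flag values through as strings, so env_args is now declared as a JSON string and decoded by VfEnvConfig.get_env_args() before being splatted into vf.load_environment(); the isinstance check keeps backward compatibility with configs that already carry a parsed dict. A minimal sketch of the round trip, assuming the VfEnvConfig from this diff and using a made-up environment kwarg (valid keys depend on the installed environment):

    # Sketch only: "num_eval_examples" is a placeholder kwarg,
    # not a documented option of any particular environment.
    from environments.verifiers_server import VfEnvConfig

    config = VfEnvConfig(
        vf_env_name="primeintellect/alphabet-sort",
        env_args='{"num_eval_examples": 64}',
    )
    # get_env_args() decodes the JSON string into the dict that
    # __init__ forwards to vf.load_environment().
    assert config.get_env_args() == {"num_eval_examples": 64}

On the command line, following the --env.* flag convention shown in the module docstring (the exact flag name is an assumption; the docstring examples were not updated to show it), this would look like:

    python verifiers_server.py serve \
        --env.vf_env_name "primeintellect/alphabet-sort" \
        --env.env_args '{"num_eval_examples": 64}' \
        --openai.base_url http://localhost:9001/v1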
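
Note on the dataset schema: setup() stops projecting columns because, per the new comment, rows from vf_env.get_dataset() already arrive rollout-ready, and _build_initial_messages is deleted because item["prompt"] presumably already holds a chat message list (it replaces the helper-built message list in the same "prompt" slot). An illustrative sketch of the item shape consumed by collect_trajectories; only the keys are taken from the diff, the values are invented:

    # Keys mirror what the code reads: item["prompt"], item["example_id"],
    # item.get("answer", ""), item.get("task", ...), item.get("info", {}).
    item = {
        "prompt": [
            {"role": "system", "content": "Sort the given words alphabetically."},
            {"role": "user", "content": "banana, apple, cherry"},
        ],
        "answer": "apple, banana, cherry",
        "example_id": 0,
        "task": "primeintellect/alphabet-sort",
        "info": {},
    }

Each item is duplicated group_size times (the for _ in range(self.config.group_size) comprehension), so every rollout in a group shares the same prompt, answer, and example_id.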