fix env_args, dataset/prompt loading

balyan.sid@gmail.com 2026-01-12 10:39:43 +05:30
parent 7907ffd0ad
commit a1d1e7d7fe


@@ -8,13 +8,13 @@ Supports TWO modes:
Usage:
# RL Training (requires local vLLM/SGLang server)
python verifiers_server.py serve \
--env.vf_env_name "will/wordle" \
--env.vf_env_name "primeintellect/alphabet-sort" \
--openai.base_url http://localhost:9001/v1 \
--slurm false
# SFT Data Generation with OpenAI GPT-4o
python verifiers_server.py process \
--env.vf_env_name "will/wordle" \
--env.vf_env_name "primeintellect/alphabet-sort" \
--env.data_path_to_save_groups gpt4o_sft_data.jsonl \
--env.total_steps 100 \
--env.group_size 4 \
@@ -23,30 +23,30 @@ Usage:
# SFT Data Generation with local server
python verifiers_server.py process \
--env.vf_env_name "will/wordle" \
--env.vf_env_name "primeintellect/alphabet-sort" \
--env.data_path_to_save_groups local_sft_data.jsonl \
--openai.base_url http://localhost:9001/v1
# Evaluation (uses ManagedServer by default, falls back to direct API in process mode)
python verifiers_server.py evaluate \
--env.vf_env_name "will/wordle" \
--env.vf_env_name "primeintellect/alphabet-sort" \
--openai.base_url http://localhost:9001/v1
To install a Verifiers/Prime environment:
1. uv tool install prime
2. prime login
3. prime env install will/wordle (or any owner/environment)
3. prime env install primeintellect/alphabet-sort (or any owner/environment)
Docs: https://docs.primeintellect.ai/tutorials-environments/install
"""
import asyncio
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import verifiers as vf
from openai import AsyncOpenAI
from pydantic import Field
from atroposlib.envs.base import (
APIServerConfig,
@@ -61,7 +61,13 @@ logger = logging.getLogger(__name__)
class VfEnvConfig(BaseEnvConfig):
vf_env_name: str = ""
env_args: Dict[str, Any] = Field(default_factory=dict)
env_args: str = "{}"
def get_env_args(self) -> Dict[str, Any]:
"""Parse env_args JSON string into dict."""
if isinstance(self.env_args, dict):
return self.env_args
return json.loads(self.env_args)
class VerifiersEnv(BaseEnv):
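The hunk above swaps the env_args Dict field for a JSON string so it can be supplied on the command line (presumably as --env.env_args, by analogy with the other --env.* flags shown in the docstring; that spelling and the max_turns key below are assumptions, not part of this commit). A minimal standalone sketch of the parsing behaviour:

import json
from typing import Any, Dict, Union

def parse_env_args(env_args: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    # Mirrors VfEnvConfig.get_env_args: accept an already-parsed dict,
    # otherwise decode the JSON string; the default "{}" yields an empty dict.
    if isinstance(env_args, dict):
        return env_args
    return json.loads(env_args)

assert parse_env_args("{}") == {}
assert parse_env_args('{"max_turns": 3}') == {"max_turns": 3}  # hypothetical key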
@@ -85,7 +91,10 @@ class VerifiersEnv(BaseEnv):
self.groups_total: int = 0
logger.info("Loading verifiers environment: %s", config.vf_env_name)
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
env_args = config.get_env_args()
if env_args:
logger.info("Environment args: %s", env_args)
self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
self.rubric = self.vf_env.rubric
self.system_prompt = self.vf_env.system_prompt
@@ -93,6 +102,10 @@ class VerifiersEnv(BaseEnv):
self.reward_func_names = self.rubric._get_reward_func_names()
logger.info("Reward functions: %s", self.reward_func_names)
# Log multi-turn config if available
if hasattr(self.vf_env, "max_turns"):
logger.info("Max turns: %d", self.vf_env.max_turns)
@classmethod
def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
env_config = VfEnvConfig(
@@ -172,10 +185,9 @@ class VerifiersEnv(BaseEnv):
await super().wandb_log(wandb_metrics)
async def setup(self):
# Dataset already has: prompt, answer, info, example_id, task
train_data = self.vf_env.get_dataset()
columns_to_keep = ["question", "answer", "info"]
available_columns = [c for c in columns_to_keep if c in train_data.column_names]
self.train = train_data.select_columns(available_columns).to_list()
self.train = train_data.to_list()
self.iter = 0
def save_checkpoint(self, step, data=None):
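setup() now keeps every dataset column instead of projecting onto question/answer/info, since collect_trajectories consumes the verifiers-native fields directly. A hypothetical item after train_data.to_list() (values invented; the field names come from the comment in the hunk above):

item = {
    "prompt": [{"role": "user", "content": "Sort alphabetically: pear, apple"}],
    "answer": "apple, pear",
    "info": {},
    "example_id": 0,
    "task": "primeintellect/alphabet-sort",
}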
@@ -193,13 +205,6 @@ class VerifiersEnv(BaseEnv):
"""No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
return {}
def _build_initial_messages(self, question: str) -> List[Dict[str, str]]:
messages: List[Dict[str, str]] = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": question})
return messages
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
is_process_mode = getattr(self, "process_mode", False)
if is_process_mode:
@@ -217,16 +222,16 @@ class VerifiersEnv(BaseEnv):
timeout=server_config.timeout,
)
initial_messages = self._build_initial_messages(item["question"])
# item already has prompt, answer, example_id, task, info from dataset
inputs = [
{
"prompt": initial_messages,
"answer": item["answer"],
"example_id": i,
"task": self.config.vf_env_name,
"prompt": item["prompt"],
"answer": item.get("answer", ""),
"example_id": item["example_id"],
"task": item.get("task", self.config.vf_env_name),
"info": item.get("info", {}),
}
for i in range(self.config.group_size)
for _ in range(self.config.group_size)
]
results = await self.vf_env.generate(
@@ -279,7 +284,9 @@ class VerifiersEnv(BaseEnv):
# Capture metrics for wandb logging
self.reward_buffer.append(reward)
self.num_turns_buffer.append(len(trajectory))
num_turns = len(trajectory)
self.num_turns_buffer.append(num_turns)
logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
# Extract per-function metrics from verifiers state
state_metrics = state.get("metrics", {})
@@ -288,6 +295,15 @@ class VerifiersEnv(BaseEnv):
if isinstance(metric_value, (int, float)):
self.metrics_buffer[metric_name].append(float(metric_value))
# Log group summary
turns = [len(s.get("trajectory", [])) for s in results["state"]]
logger.info(
"Group: %d rollouts, turns=%s, rewards=%s",
len(results["state"]),
turns,
[f"{s:.3f}" for s in scored_data["scores"]],
)
# Track group-level identical scores for debugging
self.groups_total += 1
if len(set(scored_data["scores"])) == 1:
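For reference, the group-summary logger.info in the hunk above renders the turn counts and formatted rewards as plain Python lists (numbers below are invented); a standalone reproduction of the message:

# Prints: Group: 4 rollouts, turns=[3, 2, 3, 1], rewards=['0.750', '1.000', '0.250', '0.750']
print("Group: %d rollouts, turns=%s, rewards=%s"
      % (4, [3, 2, 3, 1], [f"{r:.3f}" for r in (0.75, 1.0, 0.25, 0.75)]))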
@@ -307,7 +323,6 @@ class VerifiersEnv(BaseEnv):
AtroposManagedClient,
)
initial_messages = self._build_initial_messages(item["question"])
sampling_args = {
"temperature": 1.0,
"max_completion_tokens": self.config.max_token_length,
@@ -315,17 +330,17 @@ class VerifiersEnv(BaseEnv):
score_sem = asyncio.Semaphore(self.config.group_size)
model = self.server.servers[0].config.model_name
async def run_rollout(example_id: int) -> Dict[str, Any]:
async def run_rollout() -> Dict[str, Any]:
"""Run a single rollout and return full state for metrics extraction."""
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
client = AtroposManagedClient(managed_server=managed, model=model)
state = await self.vf_env.rollout(
input={
"prompt": initial_messages,
"answer": item["answer"],
"example_id": example_id,
"task": self.config.vf_env_name,
"prompt": item["prompt"],
"answer": item.get("answer", ""),
"example_id": item["example_id"],
"task": item.get("task", self.config.vf_env_name),
"info": item.get("info", {}),
},
client=client,
@@ -336,7 +351,7 @@ class VerifiersEnv(BaseEnv):
return state
states = await asyncio.gather(
*[run_rollout(i) for i in range(self.config.group_size)]
*[run_rollout() for _ in range(self.config.group_size)]
)
scored_data = ScoredDataGroup()
@@ -356,7 +371,9 @@ class VerifiersEnv(BaseEnv):
self.reward_buffer.append(reward)
trajectory = state.get("trajectory", [])
self.num_turns_buffer.append(len(trajectory))
num_turns = len(trajectory)
self.num_turns_buffer.append(num_turns)
logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
# Extract per-function metrics from verifiers state
state_metrics = state.get("metrics", {})
@@ -365,6 +382,15 @@ class VerifiersEnv(BaseEnv):
if isinstance(metric_value, (int, float)):
self.metrics_buffer[metric_name].append(float(metric_value))
# Log group summary
turns = [len(s.get("trajectory", [])) for s in states]
logger.info(
"Group: %d rollouts, turns=%s, rewards=%s",
len(states),
turns,
[f"{s:.3f}" for s in scored_data["scores"]],
)
# Track group-level identical scores for debugging
self.groups_total += 1
if len(set(scored_data["scores"])) == 1: