# Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-23)
"""
|
|
Verifiers Training Environment for Atropos
|
|
|
|
Supports TWO modes:
|
|
- serve: RL training with local inference server (requires ManagedServer for logprobs)
|
|
- process: SFT data generation with ANY API (OpenAI, Claude, local, etc.)
|
|
|
|
Usage:
|
|
# RL Training (requires local vLLM/SGLang server)
|
|
python verifiers_server.py serve \
|
|
--env.vf_env_name "primeintellect/alphabet-sort" \
|
|
--openai.base_url http://localhost:9001/v1 \
|
|
--slurm false
|
|
|
|
# SFT Data Generation with OpenAI GPT-4o
|
|
python verifiers_server.py process \
|
|
--env.vf_env_name "primeintellect/alphabet-sort" \
|
|
--env.data_path_to_save_groups gpt4o_sft_data.jsonl \
|
|
--env.total_steps 100 \
|
|
--env.group_size 4 \
|
|
--openai.model_name gpt-4o \
|
|
--openai.base_url https://api.openai.com/v1
|
|
|
|
# SFT Data Generation with local server
|
|
python verifiers_server.py process \
|
|
--env.vf_env_name "primeintellect/alphabet-sort" \
|
|
--env.data_path_to_save_groups local_sft_data.jsonl \
|
|
--openai.base_url http://localhost:9001/v1
|
|
|
|
# Evaluation (uses ManagedServer by default, falls back to direct API in process mode)
|
|
python verifiers_server.py evaluate \
|
|
--env.vf_env_name "primeintellect/alphabet-sort" \
|
|
--openai.base_url http://localhost:9001/v1
|
|
|
|
To install a Verifiers/Prime environment:
|
|
1. uv tool install prime
|
|
2. prime login
|
|
3. prime env install primeintellect/alphabet-sort (or any owner/environment)
|
|
Docs: https://docs.primeintellect.ai/tutorials-environments/install
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from collections import defaultdict
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import verifiers as vf
|
|
from openai import AsyncOpenAI
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VfEnvConfig(BaseEnvConfig):
    """Config for the verifiers wrapper environment.

    Extends the base Atropos env config with the name of the verifiers
    environment to load and a JSON blob of keyword arguments for it.
    """

    # Verifiers/Prime environment identifier, e.g. "primeintellect/alphabet-sort"
    vf_env_name: str = ""
    # JSON-encoded kwargs forwarded to vf.load_environment; a plain dict is
    # also tolerated (e.g. when constructed programmatically)
    env_args: str = "{}"

    def get_env_args(self) -> Dict[str, Any]:
        """Return env_args as a dict, decoding the JSON string when needed."""
        raw = self.env_args
        return raw if isinstance(raw, dict) else json.loads(raw)
|
|
|
|
|
|
class VerifiersEnv(BaseEnv):
    """Atropos environment that wraps a verifiers (Prime) environment.

    Rollout generation and reward scoring are delegated to the loaded
    ``vf_env``; this class adapts the results into Atropos
    ``ScoredDataGroup`` batches, in either SFT-collection ("process") or
    RL ("serve") mode, and buffers per-rollout metrics for wandb.
    """

    name = "verifiers"
    env_config_cls = VfEnvConfig  # type: ignore[assignment]

    def __init__(
        self,
        config: VfEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=False,
        testing=False,
    ):
        """Load the verifiers environment named in ``config`` and init metric buffers.

        Args:
            config: Wrapper config; ``vf_env_name`` selects the environment and
                ``env_args`` supplies its keyword arguments.
            server_configs: Inference API server definitions passed to BaseEnv.
            slurm: Forwarded to BaseEnv (cluster mode toggle).
            testing: Forwarded to BaseEnv.
        """
        super().__init__(config, server_configs, slurm, testing)

        # Metrics buffers for wandb logging; drained on each wandb_log() call.
        self.reward_buffer: List[float] = []
        self.metrics_buffer: Dict[str, List[float]] = defaultdict(list)
        self.num_turns_buffer: List[int] = []
        self.groups_with_identical_scores: int = 0
        self.groups_total: int = 0

        logger.info("Loading verifiers environment: %s", config.vf_env_name)
        env_args = config.get_env_args()
        if env_args:
            logger.info("Environment args: %s", env_args)
        self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
        self.rubric = self.vf_env.rubric
        self.system_prompt = self.vf_env.system_prompt

        # Get reward function names for metrics reporting.
        # NOTE(review): uses a private rubric API — may break across verifiers versions.
        self.reward_func_names = self.rubric._get_reward_func_names()
        logger.info("Reward functions: %s", self.reward_func_names)

        # Log multi-turn config if available
        if hasattr(self.vf_env, "max_turns"):
            logger.info("Max turns: %d", self.vf_env.max_turns)

    @classmethod
    def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
        """Return default env/server configs used when none are supplied via CLI."""
        env_config = VfEnvConfig(
            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=4,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="verifiers",
        )
        server_configs = [
            APIServerConfig(
                model_name="gpt-4.1-nano",
                base_url="https://api.openai.com/v1",
                # Placeholder key; real credentials are expected via CLI override.
                api_key="x",
                num_requests_for_eval=4,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Enhanced wandb logging with verifiers-specific metrics.

        Drains and resets all metric buffers accumulated since the last call,
        then delegates to the base implementation.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        # Log mean reward across all rollouts
        if self.reward_buffer:
            wandb_metrics["metrics/mean_reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
            # Population std dev; defined as 0.0 for a single sample.
            wandb_metrics["metrics/reward_std"] = (
                (
                    sum(
                        (r - wandb_metrics["metrics/mean_reward"]) ** 2
                        for r in self.reward_buffer
                    )
                    / len(self.reward_buffer)
                )
                ** 0.5
                if len(self.reward_buffer) > 1
                else 0.0
            )
            self.reward_buffer = []

        # Log per-reward-function metrics (e.g., strict_accuracy, format_score)
        if self.metrics_buffer:
            for metric_name, values in self.metrics_buffer.items():
                if values:
                    avg_metric = sum(values) / len(values)
                    wandb_metrics[f"metrics/{metric_name}"] = avg_metric
            self.metrics_buffer = defaultdict(list)

        # Log multi-turn statistics
        if self.num_turns_buffer:
            wandb_metrics["metrics/avg_num_turns"] = sum(self.num_turns_buffer) / len(
                self.num_turns_buffer
            )
            wandb_metrics["metrics/max_num_turns"] = max(self.num_turns_buffer)
            self.num_turns_buffer = []

        # Log group filtering statistics (helpful for debugging)
        if self.groups_total > 0:
            wandb_metrics["metrics/groups_with_identical_scores"] = (
                self.groups_with_identical_scores
            )
            wandb_metrics["metrics/groups_total"] = self.groups_total
            wandb_metrics["metrics/identical_score_rate"] = (
                self.groups_with_identical_scores / self.groups_total
            )
            # Reset counters
            self.groups_with_identical_scores = 0
            self.groups_total = 0

        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the training split from the verifiers environment."""
        # Dataset already has: prompt, answer, info, example_id, task
        train_data = self.vf_env.get_dataset()
        self.train = train_data.to_list()
        # Round-robin cursor over self.train; persisted by save_checkpoint.
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Persist the dataset cursor alongside the base checkpoint state."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def get_next_item(self):
        """Return the next dataset item, wrapping around when exhausted."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item

    async def evaluate(self) -> Dict[str, float]:
        """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
        return {}

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Dispatch to SFT or RL collection based on the run mode.

        ``process_mode`` is read via getattr so the attribute may be absent
        (defaults to RL/serve behavior).
        """
        is_process_mode = getattr(self, "process_mode", False)
        if is_process_mode:
            return await self._collect_for_sft(item)
        return await self._collect_for_rl(item)

    async def _collect_for_sft(
        self, item: Dict[str, Any]
    ) -> Tuple[ScoredDataGroup, list]:
        """SFT mode: uses vf_env.generate() for parallel generation + scoring."""
        # Talk to the first configured API server directly (no ManagedServer,
        # so no logprobs — fine for SFT data where only tokens/masks matter).
        server_config = self.server.servers[0].config
        client = AsyncOpenAI(
            api_key=server_config.api_key,
            base_url=server_config.base_url,
            timeout=server_config.timeout,
        )

        # item already has prompt, answer, example_id, task, info from dataset;
        # duplicate it group_size times so generate() produces one group.
        inputs = [
            {
                "prompt": item["prompt"],
                "answer": item.get("answer", ""),
                "example_id": item["example_id"],
                "task": item.get("task", self.config.vf_env_name),
                "info": item.get("info", {}),
            }
            for _ in range(self.config.group_size)
        ]

        results = await self.vf_env.generate(
            inputs=inputs,
            client=client,
            model=server_config.model_name,
            sampling_args={
                "temperature": 1.0,
                "max_completion_tokens": self.config.max_token_length,
            },
            max_concurrent=self.config.group_size,
            max_concurrent_scoring=self.config.group_size,
            save_results=False,
            independent_scoring=True,
        )

        scored_data = ScoredDataGroup()
        scored_data["tokens"] = []
        scored_data["masks"] = []
        scored_data["scores"] = []
        scored_data["messages"] = []

        for state in results["state"]:
            messages = list(state.get("prompt", [])) + list(state.get("completion", []))
            # Normalize None content to "" — presumably tool-call-only turns
            # would otherwise break the chat template; TODO confirm.
            messages = [
                {**msg, "content": msg.get("content") or ""} for msg in messages
            ]

            trajectory = state.get("trajectory", [])
            # Use the final step's finish_reason; default to "stop" when the
            # trajectory is empty.
            finish_reason = (
                trajectory[-1]["response"].choices[0].finish_reason
                if trajectory
                else "stop"
            )

            tokenized = tokenize_for_trainer(
                tokenizer=self.tokenizer,
                chat=messages,
                include_messages=True,
                finish_reason=finish_reason,
                train_on_all_assistant_turns=True,
            )

            scored_data["tokens"].append(tokenized["tokens"])
            scored_data["masks"].append(tokenized["masks"])
            scored_data["messages"].append(messages)

            reward = state.get("reward", 0.0)
            scored_data["scores"].append(reward)

            # Capture metrics for wandb logging
            self.reward_buffer.append(reward)
            num_turns = len(trajectory)
            self.num_turns_buffer.append(num_turns)
            logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)

            # Extract per-function metrics from verifiers state
            state_metrics = state.get("metrics", {})
            if state_metrics:
                for metric_name, metric_value in state_metrics.items():
                    if isinstance(metric_value, (int, float)):
                        self.metrics_buffer[metric_name].append(float(metric_value))

        # Log group summary
        turns = [len(s.get("trajectory", [])) for s in results["state"]]
        logger.info(
            "Group: %d rollouts, turns=%s, rewards=%s",
            len(results["state"]),
            turns,
            [f"{s:.3f}" for s in scored_data["scores"]],
        )

        # Track group-level identical scores for debugging
        self.groups_total += 1
        if len(set(scored_data["scores"])) == 1:
            self.groups_with_identical_scores += 1
            logger.debug(
                "Group has identical scores (%.3f) - will be filtered by base env",
                scored_data["scores"][0],
            )

        return scored_data, []

    async def _collect_for_rl(
        self, item: Dict[str, Any]
    ) -> Tuple[ScoredDataGroup, list]:
        """RL mode: uses ManagedServer for logprobs tracking."""
        # Local import keeps the managed-client dependency out of process mode.
        from atroposlib.envs.server_handling.atropos_managed_client import (
            AtroposManagedClient,
        )

        sampling_args = {
            "temperature": 1.0,
            "max_completion_tokens": self.config.max_token_length,
        }
        # Bounds concurrent scoring across the group's rollouts.
        score_sem = asyncio.Semaphore(self.config.group_size)
        model = self.server.servers[0].config.model_name

        async def run_rollout() -> Dict[str, Any]:
            """Run a single rollout and return full state for metrics extraction."""
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                client = AtroposManagedClient(managed_server=managed, model=model)

                state = await self.vf_env.rollout(
                    input={
                        "prompt": item["prompt"],
                        "answer": item.get("answer", ""),
                        "example_id": item["example_id"],
                        "task": item.get("task", self.config.vf_env_name),
                        "info": item.get("info", {}),
                    },
                    client=client,
                    model=model,
                    sampling_args=sampling_args,
                )
                # Score in-place; rubric writes reward/metrics into state.
                await self.rubric.score_rollout(state, score_sem=score_sem)
                return state

        # One rollout per group slot, all on the same item.
        states = await asyncio.gather(
            *[run_rollout() for _ in range(self.config.group_size)]
        )

        scored_data = ScoredDataGroup()
        scored_data["tokens"] = []
        scored_data["masks"] = []
        scored_data["scores"] = []
        scored_data["inference_logprobs"] = []

        for state in states:
            tokens, masks, logprobs, reward = self._extract_from_state(state)
            scored_data["tokens"].append(tokens)
            scored_data["masks"].append(masks)
            scored_data["inference_logprobs"].append(logprobs)
            scored_data["scores"].append(reward)

            # Capture metrics for wandb logging
            self.reward_buffer.append(reward)

            trajectory = state.get("trajectory", [])
            num_turns = len(trajectory)
            self.num_turns_buffer.append(num_turns)
            logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)

            # Extract per-function metrics from verifiers state
            state_metrics = state.get("metrics", {})
            if state_metrics:
                for metric_name, metric_value in state_metrics.items():
                    if isinstance(metric_value, (int, float)):
                        self.metrics_buffer[metric_name].append(float(metric_value))

        # Log group summary
        turns = [len(s.get("trajectory", [])) for s in states]
        logger.info(
            "Group: %d rollouts, turns=%s, rewards=%s",
            len(states),
            turns,
            [f"{s:.3f}" for s in scored_data["scores"]],
        )

        # Track group-level identical scores for debugging
        self.groups_total += 1
        if len(set(scored_data["scores"])) == 1:
            self.groups_with_identical_scores += 1
            logger.debug(
                "Group has identical scores (%.3f) - will be filtered by base env",
                scored_data["scores"][0],
            )

        return scored_data, []

    def _extract_from_state(
        self, state: Any
    ) -> Tuple[List[int], List[int], List[float], float]:
        """Extract tokens/masks/logprobs from rollout state (RL mode only).

        Returns:
            (tokens, masks, logprobs, reward) — flat per-token sequences
            concatenated across trajectory steps, plus the scalar reward.
        """
        all_tokens: List[int] = []
        all_masks: List[int] = []
        all_logprobs: List[float] = []

        # NOTE(review): each step's prompt_ids are appended in full; if later
        # steps' prompts repeat earlier context, tokens would be duplicated —
        # confirm against vf_env.rollout's per-step token layout.
        for step in state.get("trajectory", []):
            tokens = step["tokens"]
            prompt_ids = tokens["prompt_ids"]
            completion_ids = tokens["completion_ids"]
            completion_logprobs = tokens["completion_logprobs"]

            all_tokens.extend(prompt_ids)
            all_tokens.extend(completion_ids)
            # Masks: -100 for prompt positions, the token id for completion
            # positions — presumably the trainer's trainable-token convention.
            all_masks.extend([-100] * len(prompt_ids))
            all_masks.extend(completion_ids)
            # 1.0 is a placeholder logprob for non-sampled prompt tokens —
            # TODO confirm the trainer ignores these positions.
            all_logprobs.extend([1.0] * len(prompt_ids))
            all_logprobs.extend(completion_logprobs)

        return all_tokens, all_masks, all_logprobs, state["reward"]
|
|
|
|
|
|
# CLI entry point: BaseEnv.cli handles serve/process/evaluate subcommands.
if __name__ == "__main__":
    VerifiersEnv.cli()