parallelize verifiers_server: use generate() for SFT, parallel ManagedServer contexts for RL
balyan.sid@gmail.com 2026-01-12 07:20:56 +05:30
parent 24b4488c60
commit dceb1d8fd8


@@ -41,13 +41,11 @@ Docs: https://docs.primeintellect.ai/tutorials-environments/install
import asyncio
import logging
-import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import verifiers as vf
from openai import AsyncOpenAI
from pydantic import Field
-from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
@@ -60,59 +58,7 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
-# =============================================================================
-# Verifiers API Compatibility Layer
-# =============================================================================
-def _get_rubric_reward_funcs(rubric: vf.Rubric) -> List[Callable]:
-    """
-    Get reward functions from a Rubric with API version compatibility.
-    Handles different verifiers API versions:
-    - v0.1.9+: rubric._get_reward_funcs() (private)
-    - v0.1.5-0.1.8: rubric.get_reward_funcs() (public)
-    - fallback: rubric.funcs (direct access)
-    """
-    if hasattr(rubric, "_get_reward_funcs"):
-        return rubric._get_reward_funcs()
-    elif hasattr(rubric, "get_reward_funcs"):
-        return rubric.get_reward_funcs()
-    elif hasattr(rubric, "funcs"):
-        return rubric.funcs
-    else:
-        raise AttributeError(
-            f"Cannot find reward functions on rubric. "
-            f"Available attrs: {dir(rubric)}"
-        )
-def _get_rubric_reward_weights(rubric: vf.Rubric) -> List[float]:
-    """
-    Get reward weights from a Rubric with API version compatibility.
-    Handles different verifiers API versions:
-    - v0.1.9+: rubric._get_reward_weights() (private)
-    - v0.1.5-0.1.8: rubric.get_reward_weights() (public)
-    - fallback: rubric.weights (direct access)
-    """
-    if hasattr(rubric, "_get_reward_weights"):
-        return rubric._get_reward_weights()
-    elif hasattr(rubric, "get_reward_weights"):
-        return rubric.get_reward_weights()
-    elif hasattr(rubric, "weights"):
-        return rubric.weights
-    else:
-        raise AttributeError(
-            f"Cannot find reward weights on rubric. " f"Available attrs: {dir(rubric)}"
-        )
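For context on what this commit deletes: a minimal sketch of how the hasattr chain resolved across verifiers versions. `FakeRubric` is a hypothetical stand-in, not the real `vf.Rubric`.

```python
# Hypothetical stand-in mimicking verifiers v0.1.5-0.1.8, where the
# accessor was public; the shim resolves it via its second hasattr branch.
class FakeRubric:
    def get_reward_funcs(self):
        return [lambda completion, answer, **kwargs: 1.0]

funcs = _get_rubric_reward_funcs(FakeRubric())  # one reward function
```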
class VfEnvConfig(BaseEnvConfig):
"""
Configuration for the Verifiers environments.
"""
vf_env_name: str = ""
env_args: Dict[str, Any] = Field(default_factory=dict)
@@ -130,38 +76,12 @@ class VerifiersEnv(BaseEnv):
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer: List[float] = []
-        self.eval_metrics: List[Tuple[str, float]] = []
# Load verifiers environment
logger.info("Loading verifiers environment: %s", config.vf_env_name)
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
self.rubric = self.vf_env.rubric
-        self.parser = self.rubric.parser
-        # Handle both single Rubric and RubricGroup (composite)
-        # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics
-        if hasattr(self.rubric, "rubrics"):
-            # RubricGroup: collect from all individual rubrics
-            self.reward_funcs: List[Callable] = []
-            self.reward_weights: List[float] = []
-            for rubric in self.rubric.rubrics:  # type: ignore[attr-defined]
-                self.reward_funcs.extend(_get_rubric_reward_funcs(rubric))
-                self.reward_weights.extend(_get_rubric_reward_weights(rubric))
-        else:
-            # Single Rubric: use compatibility layer
-            self.reward_funcs = _get_rubric_reward_funcs(self.rubric)
-            self.reward_weights = _get_rubric_reward_weights(self.rubric)
-        total = sum(self.reward_weights) if self.reward_weights else 1.0
-        self.reward_scales = [weight / total for weight in self.reward_weights]
self.system_prompt = self.vf_env.system_prompt
-        logger.info(
-            "Loaded environment with %d reward functions, system_prompt=%s",
-            len(self.reward_funcs),
-            bool(self.system_prompt),
-        )
-        logger.info("Reward functions: %s", self.rubric._get_reward_func_names())
@classmethod
def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
@@ -176,9 +96,6 @@ class VerifiersEnv(BaseEnv):
max_token_length=2048,
wandb_name="verifiers",
)
-        # Default config for local inference server (vLLM, SGLang, TRL)
-        # For SFT data generation with OpenAI, override via CLI:
-        # --openai.base_url https://api.openai.com/v1 --openai.model_name gpt-4o
server_configs = [
APIServerConfig(
model_name="gpt-4.1-nano",
@@ -193,31 +110,19 @@ class VerifiersEnv(BaseEnv):
if wandb_metrics is None:
wandb_metrics = {}
# Calculate percent_correct from buffer
if self.percent_correct_buffer:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
-        self.percent_correct_buffer = list()
-        for item in self.eval_metrics:
-            wandb_metrics[item[0]] = item[1]
-        self.eval_metrics = list()
+        self.percent_correct_buffer = []
await super().wandb_log(wandb_metrics)
async def setup(self):
train_data = self.vf_env.get_dataset()
# Only load columns we need to avoid memory bloat
columns_to_keep = ["question", "answer", "info"]
available_columns = [c for c in columns_to_keep if c in train_data.column_names]
self.train = train_data.select_columns(available_columns).to_list()
test_data = self.vf_env.get_eval_dataset()
available_test_columns = [
c for c in columns_to_keep if c in test_data.column_names
]
self.test = test_data.select_columns(available_test_columns).to_list()
self.iter = 0
def save_checkpoint(self, step, data=None):
@@ -226,145 +131,32 @@ class VerifiersEnv(BaseEnv):
data["iter"] = self.iter
super().save_checkpoint(step, data)
-    def _compute_score(self, completion_messages: List[Dict], answer: str) -> float:
-        """Compute score using verifiers reward functions."""
-        rewards = []
-        for func in self.reward_funcs:
-            reward = func(
-                parser=self.parser,
-                completion=completion_messages,
-                answer=answer,
-            )
-            rewards.append(reward)
-        weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
-        return sum(weighted_rewards)
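A toy illustration of the weighting the removed `_compute_score` applied; all numbers are invented:

```python
# Two reward functions with raw weights 2 and 1; scales are normalized to sum to 1.
weights = [2.0, 1.0]
scales = [w / sum(weights) for w in weights]         # [0.667, 0.333]
rewards = [1.0, 0.5]                                 # raw per-function rewards
score = sum(r * s for r, s in zip(rewards, scales))  # 1.0 * 2/3 + 0.5 * 1/3 = 0.833
```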
-    async def rollout_and_score_eval(
-        self, question: str, answer: str, **kwargs
-    ) -> dict:
-        """
-        Rollout and score for evaluation.
-        Uses ManagedServer in serve mode, direct API calls in process mode.
-        """
-        system_prompt = kwargs.get("system_prompt")
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": question},
-        ]
-        is_process_mode = getattr(self, "process_mode", False)
-        if is_process_mode:
-            # Process mode: use direct API call (works with any API)
-            completion = await self.server.chat_completion(
-                messages=messages,
-                n=1,
-                max_tokens=self.config.max_token_length,
-                temperature=0.0,
-            )
-        else:
-            # Serve mode: use ManagedServer for token tracking
-            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
-                completion = await managed.chat_completion(
-                    messages=messages,
-                    n=1,
-                    max_tokens=self.config.max_token_length,
-                    temperature=0.0,
-                )
-        response_content = completion.choices[0].message.content or ""
-        messages.append({"role": "assistant", "content": response_content})
-        answer_parsed = self.parser.parse_answer(completion=response_content)
-        score = self._compute_score(messages, answer)
-        sample = {
-            "messages": messages,
-            "question": question,
-            "gold_answer": answer,
-            "model_parsed": str(answer_parsed) if answer_parsed else None,
-            "score": score,
-            "correct": bool(score),
-            "finish_reason": completion.choices[0].finish_reason,
-        }
-        return {"score": score, "sample": sample}
-    async def evaluate(self, *args, **kwargs):
-        start_time = time.time()
-        eval_tasks = []
-        for item in self.test:
-            eval_tasks.append(
-                self.rollout_and_score_eval(
-                    item["question"], item["answer"], system_prompt=self.system_prompt
-                )
-            )
-        results = await tqdm_asyncio.gather(*eval_tasks)
-        scores = [result["score"] for result in results]
-        samples = [result["sample"] for result in results]
-        avg_total_score = sum(scores) / len(scores)
-        end_time = time.time()
-        self.eval_metrics.append(("eval/avg_total_score", avg_total_score))
-        eval_metrics = {"eval/avg_total_score": avg_total_score}
-        await self.evaluate_log(
-            metrics=eval_metrics,
-            samples=samples,
-            start_time=start_time,
-            end_time=end_time,
-            generation_parameters={
-                "temperature": 0.0,
-                "max_tokens": self.config.max_token_length,
-            },
-        )
-        return eval_metrics
async def get_next_item(self):
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
return next_item
+    async def evaluate(self) -> Dict[str, float]:
+        """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
+        return {}
+    def _build_initial_messages(self, question: str) -> List[Dict[str, str]]:
+        messages: List[Dict[str, str]] = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": question})
+        return messages
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
"""
Collect trajectories - switches between:
- SFT data generation (process mode): Any API, no logprobs needed
- RL training (serve mode): Local server with logprobs
"""
is_process_mode = getattr(self, "process_mode", False)
if is_process_mode:
-            return await self._collect_trajectories_for_sft(item)
-        else:
-            return await self._collect_trajectories_for_rl(item)
+            return await self._collect_for_sft(item)
+        return await self._collect_for_rl(item)
-    async def _collect_trajectories_for_sft(
+    async def _collect_for_sft(
self, item: Dict[str, Any]
) -> Tuple[ScoredDataGroup, list]:
"""
SFT data generation mode - works with ANY API (OpenAI, Claude, local).
Does NOT require logprobs or local server.
Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer
to tokenize completions with your training tokenizer.
"""
question = item["question"]
answer = item["answer"]
# Build initial messages
initial_messages: List[Dict[str, str]] = []
if self.system_prompt:
initial_messages.append({"role": "system", "content": self.system_prompt})
initial_messages.append({"role": "user", "content": question})
# Create AsyncOpenAI client directly from server config (no ManagedServer needed)
"""SFT mode: uses vf_env.generate() for parallel generation + scoring."""
server_config = self.server.servers[0].config
client = AsyncOpenAI(
api_key=server_config.api_key,
@@ -372,11 +164,31 @@ class VerifiersEnv(BaseEnv):
timeout=server_config.timeout,
)
-        # Sampling args - use max_completion_tokens for newer models like gpt-5
-        sampling_args = {
-            "temperature": 1.0,
-            "max_completion_tokens": self.config.max_token_length,
-        }
+        initial_messages = self._build_initial_messages(item["question"])
+        inputs = [
+            {
+                "prompt": initial_messages,
+                "answer": item["answer"],
+                "example_id": i,
+                "task": self.config.vf_env_name,
+                "info": item.get("info", {}),
+            }
+            for i in range(self.config.group_size)
+        ]
+        results = await self.vf_env.generate(
+            inputs=inputs,
+            client=client,
+            model=server_config.model_name,
+            sampling_args={
+                "temperature": 1.0,
+                "max_completion_tokens": self.config.max_token_length,
+            },
+            max_concurrent=self.config.group_size,
+            max_concurrent_scoring=self.config.group_size,
+            save_results=False,
+            independent_scoring=True,
+        )
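A standalone sketch of this generate()-based path, mirroring the keyword arguments above; the environment name, base URL, and model are placeholders:

```python
import asyncio

import verifiers as vf
from openai import AsyncOpenAI


async def main():
    env = vf.load_environment("my-env")  # placeholder environment name
    client = AsyncOpenAI(base_url="http://localhost:9001/v1", api_key="x")  # placeholder server
    results = await env.generate(
        inputs=[{
            "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
            "answer": "4",
            "example_id": 0,
            "task": "my-env",
            "info": {},
        }],
        client=client,
        model="gpt-4.1-nano",
        sampling_args={"temperature": 1.0, "max_completion_tokens": 2048},
        save_results=False,
        independent_scoring=True,
    )
    print(results["state"][0].get("reward", 0.0))


asyncio.run(main())
```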
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
@@ -384,63 +196,22 @@
scored_data["scores"] = []
scored_data["messages"] = []
-        # Semaphore for scoring (required by rubric.score_rollout)
-        score_sem = asyncio.Semaphore(1)
-        # Run rollouts in parallel for group_size
-        async def run_single_rollout(example_id: int):
-            # Pass through any info from the dataset item (e.g., docker_image for SWE envs)
-            item_info = item.get("info", {})
-            rollout_input = {
-                "prompt": initial_messages,
-                "answer": answer,
-                "example_id": example_id,
-                "task": self.config.vf_env_name,
-                "info": item_info,
-            }
-            state = await self.vf_env.rollout(
-                input=rollout_input,
-                client=client,
-                model=server_config.model_name,
-                sampling_args=sampling_args,
-            )
-            # Score the rollout using verifiers rubric (computes reward from test output)
-            # This is needed because vf_env.rollout() doesn't call score_rollout
-            await self.rubric.score_rollout(state, score_sem=score_sem)
-            return state
-        # Run group_size rollouts in parallel
-        rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)]
-        states = await asyncio.gather(*rollout_tasks)
-        for state in states:
-            # Extract completion messages from state
-            completion_messages = list(state.get("prompt", [])) + list(
-                state.get("completion", [])
-            )
-            # Ensure all message contents are strings (not None)
-            # This can happen with tool call messages that have content: null
-            completion_messages = [
-                {**msg, "content": msg.get("content") or ""}
-                for msg in completion_messages
+        for state in results["state"]:
+            messages = list(state.get("prompt", [])) + list(state.get("completion", []))
+            messages = [
+                {**msg, "content": msg.get("content") or ""} for msg in messages
]
-            # Get reward from verifiers scoring (set by rubric.score_rollout above)
-            score = state.get("reward", 0.0)
# Determine finish reason from last trajectory step
trajectory = state.get("trajectory", [])
-            if trajectory:
-                finish_reason = trajectory[-1]["response"].choices[0].finish_reason
-            else:
-                finish_reason = "stop"
+            finish_reason = (
+                trajectory[-1]["response"].choices[0].finish_reason
+                if trajectory
+                else "stop"
+            )
# Use tokenize_for_trainer for tokenization
# train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked
# for multi-turn environments, not just the last message
tokenized = tokenize_for_trainer(
tokenizer=self.tokenizer,
-                chat=completion_messages,
+                chat=messages,
include_messages=True,
finish_reason=finish_reason,
train_on_all_assistant_turns=True,
@@ -448,39 +219,54 @@
scored_data["tokens"].append(tokenized["tokens"])
scored_data["masks"].append(tokenized["masks"])
scored_data["messages"].append(completion_messages)
scored_data["scores"].append(score)
scored_data["messages"].append(messages)
scored_data["scores"].append(state.get("reward", 0.0))
# Track scores for wandb logging
for score in scored_data["scores"]:
self.percent_correct_buffer.append(max(score, 0))
return scored_data, []
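To make `train_on_all_assistant_turns=True` concrete: in a multi-turn chat like the invented one below, both assistant messages end up unmasked (trainable), not just the final reply.

```python
# Invented multi-turn chat; content is illustrative only.
chat = [
    {"role": "user", "content": "Run the tests."},
    {"role": "assistant", "content": "Running the test suite now."},  # trainable
    {"role": "tool", "content": "3 passed, 0 failed"},                # masked
    {"role": "assistant", "content": "All tests pass."},              # trainable
]
```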
-    async def _collect_trajectories_for_rl(
+    async def _collect_for_rl(
self, item: Dict[str, Any]
) -> Tuple[ScoredDataGroup, list]:
"""
RL training mode - requires local inference server for logprobs.
Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn.
"""
"""RL mode: uses ManagedServer for logprobs tracking."""
from atroposlib.envs.server_handling.atropos_managed_client import (
AtroposManagedClient,
)
-        question = item["question"]
-        answer = item["answer"]
-        item_info = item.get("info", {})
-        initial_messages: List[Dict[str, str]] = []
-        if self.system_prompt:
-            initial_messages.append({"role": "system", "content": self.system_prompt})
-        initial_messages.append({"role": "user", "content": question})
+        initial_messages = self._build_initial_messages(item["question"])
sampling_args = {
"temperature": 1.0,
"max_completion_tokens": self.config.max_token_length,
}
+        score_sem = asyncio.Semaphore(self.config.group_size)
+        model = self.server.servers[0].config.model_name
+        async def run_rollout(
+            example_id: int,
+        ) -> Tuple[List[int], List[int], List[float], float]:
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                client = AtroposManagedClient(managed_server=managed, model=model)
+                state = await self.vf_env.rollout(
+                    input={
+                        "prompt": initial_messages,
+                        "answer": item["answer"],
+                        "example_id": example_id,
+                        "task": self.config.vf_env_name,
+                        "info": item.get("info", {}),
+                    },
+                    client=client,
+                    model=model,
+                    sampling_args=sampling_args,
+                )
+                await self.rubric.score_rollout(state, score_sem=score_sem)
+                return self._extract_from_state(state)
+        results = await asyncio.gather(
+            *[run_rollout(i) for i in range(self.config.group_size)]
+        )
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
@@ -488,44 +274,12 @@ class VerifiersEnv(BaseEnv):
scored_data["scores"] = []
scored_data["inference_logprobs"] = []
-        # Semaphore for scoring (required by rubric.score_rollout)
-        score_sem = asyncio.Semaphore(1)
+        for tokens, masks, logprobs, score in results:
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
+            scored_data["scores"].append(score)
-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
-            client = AtroposManagedClient(
-                managed_server=managed,
-                model=self.server_configs[0].model_name,
-            )
-            # Run group_size rollouts sequentially (ManagedServer state must be reset between)
-            for i in range(self.config.group_size):
-                client.reset()
-                rollout_input = {
-                    "prompt": initial_messages,
-                    "answer": answer,
-                    "example_id": i,
-                    "task": self.config.vf_env_name,
-                    "info": item_info,
-                }
-                state = await self.vf_env.rollout(
-                    input=rollout_input,
-                    client=client,
-                    model=self.server_configs[0].model_name,
-                    sampling_args=sampling_args,
-                )
-                # Score the rollout (computes reward from test output)
-                await self.rubric.score_rollout(state, score_sem=score_sem)
-                tokens, masks, logprobs, score = self._extract_from_state(state)
-                scored_data["tokens"].append(tokens)
-                scored_data["masks"].append(masks)
-                scored_data["inference_logprobs"].append(logprobs)
-                scored_data["scores"].append(score)
# Track scores for wandb logging
for score in scored_data["scores"]:
self.percent_correct_buffer.append(max(score, 0))
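The structural change in the RL path, distilled: instead of one shared ManagedServer whose client is reset between sequential rollouts, each rollout now opens its own context, so the group runs concurrently. A minimal sketch of the pattern; `do_rollout` is a placeholder:

```python
import asyncio


async def one_rollout(server, tokenizer, i):
    # Each task opens its own ManagedServer context instead of sharing a
    # single client that must be reset between sequential rollouts.
    async with server.managed_server(tokenizer=tokenizer) as managed:
        return await do_rollout(managed, i)  # do_rollout is hypothetical


async def collect(server, tokenizer, group_size):
    return await asyncio.gather(
        *(one_rollout(server, tokenizer, i) for i in range(group_size))
    )
```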
@@ -534,38 +288,25 @@
def _extract_from_state(
self, state: Any
) -> Tuple[List[int], List[int], List[float], float]:
"""
Extract tokens/masks/logprobs/score from a single rollout state.
Handles the mask convention conversion:
- Verifiers: prompt_mask=0, completion_mask=1
- Atropos: masked_tokens=-100 (prompt), token_id (completion)
"""
"""Extract tokens/masks/logprobs from rollout state (RL mode only)."""
all_tokens: List[int] = []
all_masks: List[int] = []
all_logprobs: List[float] = []
-        trajectory = state.get("trajectory", [])
-        for step in trajectory:
+        for step in state.get("trajectory", []):
tokens = step["tokens"]
prompt_ids = tokens["prompt_ids"]
completion_ids = tokens["completion_ids"]
completion_logprobs = tokens["completion_logprobs"]
all_tokens.extend(prompt_ids)
all_tokens.extend(completion_ids)
all_masks.extend([-100] * len(prompt_ids))
all_masks.extend(completion_ids)
all_logprobs.extend([1.0] * len(prompt_ids))
all_logprobs.extend(completion_logprobs)
-        reward = state["reward"]
-        return all_tokens, all_masks, all_logprobs, reward
+        return all_tokens, all_masks, all_logprobs, state["reward"]
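A worked example of the mask conversion above, with invented token ids: verifiers marks prompt tokens with 0 and completion tokens with 1, while Atropos expects -100 at masked (prompt) positions and the token id itself at trainable (completion) positions.

```python
prompt_ids = [101, 102]      # invented prompt token ids
completion_ids = [7, 8, 9]   # invented completion token ids

tokens = prompt_ids + completion_ids               # [101, 102, 7, 8, 9]
masks = [-100] * len(prompt_ids) + completion_ids  # [-100, -100, 7, 8, 9]
```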
if __name__ == "__main__":