Rework server and eval for RL rollout; add AsyncManagedServer support for
verifiers.
This commit is contained in:
balyan.sid@gmail.com 2026-01-10 14:55:08 +05:30
parent 3449a4c23d
commit cf636595d2
4 changed files with 652 additions and 91 deletions

View file

@ -0,0 +1,330 @@
"""
AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer.
This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's
ManagedServer for inference, enabling token tracking for multi-turn RL training
with the Verifiers library.
Usage:
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
client = AtroposManagedClient(managed_server=managed, model="model-name")
# Use like AsyncOpenAI - tokens are tracked automatically
response = await client.chat.completions.create(
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100
)
# Token data is available on the response:
# - response.prompt_token_ids
# - response.choices[0].token_ids
# - response.choices[0].logprobs.content[i].logprob
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode
# =============================================================================
# Enhanced Types for Token Data Injection
# =============================================================================
@dataclass
class LogprobContent:
    """
    Single token logprob entry.
    Compatible with verifiers' parse_response_tokens() which accesses:
    - response.choices[i].logprobs.content[j].logprob
    """

    # Log-probability the inference server assigned to this token.
    logprob: float
    # Decoded token string; "" when no tokenizer was available to decode it.
    token: str = ""
    # Vocabulary id of the token.
    token_id: int = 0
    # Alternative top-k candidates; not populated by the ManagedServer path here.
    top_logprobs: Optional[List[Any]] = None
@dataclass
class ChoiceLogprobs:
    """
    Logprobs structure compatible with verifiers expectations.
    Verifiers checks for either object or dict format:
    - Object: response.choices[i].logprobs.content[j].logprob
    - Dict: response.choices[i].logprobs["content"][j]["logprob"]
    This dataclass supports the object format.
    """

    # One LogprobContent entry per completion token, in generation order.
    content: List[LogprobContent] = field(default_factory=list)
@dataclass
class EnhancedChoice:
    """
    Choice with token_ids and logprobs for RL training.
    Adds the following attributes that verifiers expects:
    - token_ids: List[int] - completion token IDs
    - logprobs: ChoiceLogprobs - structured logprobs
    """

    # Position of this choice within the parent completion's choices list.
    index: int
    # Assistant message returned by the inference server.
    message: ChatCompletionMessage
    # Finish reason; defaulted to "stop" when the server omits it.
    finish_reason: str
    # Completion (non-prompt) token ids for this choice.
    token_ids: List[int]
    # Per-token logprobs in the object format verifiers reads.
    logprobs: ChoiceLogprobs
@dataclass
class EnhancedChatCompletion:
    """
    ChatCompletion with token data for RL training.
    Compatible with verifiers' parse_response_tokens() expectations:
    - prompt_token_ids: list[int]
    - choices[i].token_ids: list[int]
    - choices[i].logprobs.content[j].logprob
    """

    # Fields mirroring openai.types.chat.ChatCompletion.
    id: str
    created: int
    model: str
    object: str
    # Choices enhanced with token_ids / structured logprobs.
    choices: List[EnhancedChoice]
    # Token ids of the prompt, taken from the first sequence node.
    prompt_token_ids: List[int]
    # model_dump() of the server's usage object, when reported.
    usage: Optional[Dict[str, int]] = None
# =============================================================================
# AsyncOpenAI-Compatible Client Classes
# =============================================================================
class _CompletionsNamespace:
    """
    Mimics openai.resources.chat.completions.AsyncCompletions.

    Provides the create() method that verifiers calls via
    client.chat.completions.create(...).
    """

    def __init__(self, parent: "AtroposManagedClient"):
        # Back-reference used to reach the ManagedServer, its tokenizer,
        # and the default model name.
        self.parent = parent

    async def create(
        self,
        *,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        n: int = 1,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        tools: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> EnhancedChatCompletion:
        """
        Create chat completion with token tracking.

        Returns ChatCompletion with additional attributes:
        - prompt_token_ids: list[int]
        - choices[i].token_ids: list[int]
        - choices[i].logprobs.content: list with logprob info

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Model name (defaults to client's model)
            n: Number of completions (should be 1 for multi-turn)
            max_tokens: Max tokens in completion (legacy param)
            max_completion_tokens: Max tokens in completion (new param)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            tools: Tool definitions for function calling
            stop: Stop sequences
            **kwargs: Additional parameters passed to ManagedServer
        """
        # Prefer the newer max_completion_tokens parameter when it was
        # explicitly provided, falling back to the legacy max_tokens.
        # An explicit `is not None` check (rather than `or`) keeps a
        # caller-supplied 0 from being silently replaced.
        if max_completion_tokens is not None:
            effective_max_tokens = max_completion_tokens
        else:
            effective_max_tokens = max_tokens

        # Build kwargs for ManagedServer
        completion_kwargs: Dict[str, Any] = {
            "messages": messages,
            "model": model or self.parent.model,
            "n": n,
            "temperature": temperature,
            "top_p": top_p,
        }
        if effective_max_tokens is not None:
            completion_kwargs["max_tokens"] = effective_max_tokens
        if tools is not None:
            completion_kwargs["tools"] = tools
        if stop is not None:
            completion_kwargs["stop"] = stop
        # Pass through any extra kwargs (like logprobs settings); None
        # values are dropped so they don't override server defaults.
        for key, value in kwargs.items():
            if value is not None:
                completion_kwargs[key] = value

        # Call ManagedServer for inference
        completion = await self.parent.managed_server.chat_completion(
            **completion_kwargs
        )

        # Get token state from managed server; nodes are expected to
        # correspond 1:1 (in order) with completion.choices.
        state = self.parent.managed_server.get_state()
        nodes: List[SequenceNode] = state["nodes"]

        # Inject token data into the response
        return self._enhance_completion(completion, nodes)

    def _enhance_completion(
        self, completion: Any, nodes: List[SequenceNode]
    ) -> EnhancedChatCompletion:
        """
        Convert ManagedServer output to verifiers-compatible format.

        Extracts token data from SequenceNodes and injects it into the
        ChatCompletion response in the format verifiers expects.
        """
        enhanced_choices = []
        prompt_token_ids: List[int] = []
        # Hoisted out of the loop: the tokenizer does not change per choice.
        tokenizer = self.parent.managed_server.tokenizer

        for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
            # Find prompt/completion boundary from masked_tokens:
            # -100 marks prompt positions, real token IDs mark completion.
            prompt_len = sum(1 for m in node.masked_tokens if m == -100)

            # All choices share the prompt; record it once from the first node.
            if i == 0:
                prompt_token_ids = node.tokens[:prompt_len]
            completion_ids = node.tokens[prompt_len:]
            completion_logprobs = node.logprobs[prompt_len:]

            # Build the logprobs structure verifiers expects.
            logprobs_content = []
            for token_id, logprob in zip(completion_ids, completion_logprobs):
                # Decode token to string if a tokenizer is available; fall
                # back to a placeholder on decode failure.
                token_str = ""
                if tokenizer is not None:
                    try:
                        token_str = tokenizer.decode([token_id])
                    except Exception:
                        token_str = f"<token_{token_id}>"
                logprobs_content.append(
                    LogprobContent(
                        logprob=logprob,
                        token=token_str,
                        token_id=token_id,
                    )
                )

            # Create enhanced choice carrying the token data.
            enhanced_choices.append(
                EnhancedChoice(
                    index=choice.index,
                    message=choice.message,
                    finish_reason=choice.finish_reason or "stop",
                    token_ids=completion_ids,
                    logprobs=ChoiceLogprobs(content=logprobs_content),
                )
            )

        return EnhancedChatCompletion(
            id=completion.id,
            created=completion.created,
            model=completion.model,
            object=completion.object,
            choices=enhanced_choices,
            prompt_token_ids=prompt_token_ids,
            usage=completion.usage.model_dump() if completion.usage else None,
        )
class _ChatNamespace:
    """Mimics openai.resources.chat.AsyncChat."""

    def __init__(self, parent: "AtroposManagedClient"):
        # Expose .completions so callers can use client.chat.completions.create().
        self.completions = _CompletionsNamespace(parent)
class AtroposManagedClient:
    """
    AsyncOpenAI-compatible client backed by ManagedServer.

    Presents the same namespace layout as AsyncOpenAI
    (client.chat.completions.create) while routing inference through
    Atropos's ManagedServer, so token/logprob tracking happens
    automatically for multi-turn RL training with the Verifiers library.

    Responses carry the token-data attributes verifiers'
    parse_response_tokens() reads:
    - response.prompt_token_ids
    - response.choices[i].token_ids
    - response.choices[i].logprobs.content[j].logprob

    Usage:
        async with server_manager.managed_server(tokenizer=tokenizer) as managed:
            client = AtroposManagedClient(
                managed_server=managed,
                model="Qwen/Qwen2.5-1.5B-Instruct"
            )
            # Pass to verifiers env.rollout()
            state = await vf_env.rollout(
                input=rollout_input,
                client=client,
                model="Qwen/Qwen2.5-1.5B-Instruct",
            )
            # Token data is now in state["trajectory"][i]["tokens"]
    """

    def __init__(
        self,
        managed_server: ManagedServer,
        model: str,
        base_url: Optional[str] = None,
    ):
        """
        Initialize the managed client.

        Args:
            managed_server: ManagedServer instance for inference and token tracking
            model: Model name to use for completions
            base_url: Optional base URL (for API compatibility, not used)
        """
        self.managed_server = managed_server
        self.model = model
        # Placeholder URL keeps AsyncOpenAI-style attribute access working.
        self.base_url = base_url if base_url else "http://managed-server"
        # Replicate the AsyncOpenAI namespace structure (client.chat.*).
        self.chat = _ChatNamespace(self)

    def reset(self):
        """Clear ManagedServer token-tracking state between rollouts."""
        self.managed_server.reset()

    async def close(self):
        """No-op for API compatibility; ManagedServer owns its own cleanup."""
        return None

    def copy(self, **_kwargs) -> "AtroposManagedClient":
        """
        Return this client unchanged (for API compatibility).

        Verifiers may call client.copy() for certain operations; returning
        self keeps every caller sharing the same ManagedServer state.
        """
        return self

View file

@ -0,0 +1,31 @@
# Verifiers environment configuration
# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml
#
# For SFT data generation with external API:
# python environments/verifiers_server.py process \
# --env.vf_env_name primeintellect/gsm8k \
# --env.data_path_to_save_groups output.jsonl \
# --openai.base_url https://api.openai.com/v1 \
# --openai.api_key $OPENAI_API_KEY \
# --openai.model_name gpt-4o
env:
vf_env_name: "primeintellect/gsm8k" # Prime Env Hub environment
env_args: {}
group_size: 8
max_token_length: 2048
tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct"
rollout_server_url: "http://localhost:8000"
use_wandb: true
wandb_name: "verifiers"
total_steps: 1000
batch_size: 4
steps_per_eval: 100
openai:
- model_name: "Qwen/Qwen2.5-1.5B-Instruct"
base_url: "http://localhost:9001/v1"
api_key: "x"
slurm: false
testing: false

View file

@ -18,15 +18,15 @@ Usage:
"""
import asyncio
import inspect
import os
import time
from typing import Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import verifiers as vf
import wandb
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
@ -34,6 +34,36 @@ from atroposlib.envs.base import (
)
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(timeout_seconds: int):
"""Replacement timeout decorator that doesn't use signals."""
def decorator(func):
def wrapper(*args, **kwargs):
# Just call the function without timeout
# This is safe because we're in an async context with our own timeouts
# timeout_seconds is intentionally unused - we're replacing the timeout logic
return func(*args, **kwargs)
return wrapper
return decorator
try:
    import math_verify.grader
    import math_verify.parser
    import math_verify.utils

    # Patch all modules that use the timeout decorator.
    # NOTE(review): this replaces the module-level `timeout` attribute after
    # import — functions already decorated at their definition time would be
    # unaffected; confirm math_verify looks the decorator up at call time.
    math_verify.utils.timeout = _no_signal_timeout
    math_verify.parser.timeout = _no_signal_timeout
    math_verify.grader.timeout = _no_signal_timeout
except ImportError:
    pass  # math_verify not installed
class VerifiersEvaluationConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv):
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
self.rubric = self.vf_env.rubric
# Extract rubric components
# Extract rubric components from RubricGroup
# RubricGroup.funcs is empty - need to collect from individual rubrics
self.parser = self.rubric.parser
self.reward_funcs: List[Callable] = []
self.reward_weights: List[float] = []
self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func
if hasattr(self.rubric, "rubrics"):
# RubricGroup: collect from all individual rubrics
for rubric in self.rubric.rubrics:
class_objects = getattr(rubric, "class_objects", {})
for func, weight in zip(rubric.funcs, rubric.weights):
self.reward_funcs.append(func)
self.reward_weights.append(weight)
self.rubric_class_objects.append(class_objects)
else:
# Single Rubric
self.reward_funcs = self.rubric.funcs
self.reward_weights = self.rubric.weights
self.reward_scales = [
weight / sum(self.reward_weights) for weight in self.reward_weights
]
class_objects = getattr(self.rubric, "class_objects", {})
self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)
total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
self.reward_scales = [weight / total_weight for weight in self.reward_weights]
self.system_prompt = self.vf_env.system_prompt
# Tracking
@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv):
# Parse answer
answer_parsed = self.parser.parse_answer(completion=response_text)
# Score using reward funcs
# Score using reward funcs (async functions need await)
# Use signature introspection to pass only required params (like verifiers does)
rewards = []
for func in self.reward_funcs:
reward = func(
parser=self.parser,
completion=completion_messages,
answer=answer,
)
for i, func in enumerate(self.reward_funcs):
try:
# Build merged dict of all possible parameters
class_objects = self.rubric_class_objects[i]
merged = {
"completion": completion_messages,
"answer": answer,
"prompt": question,
}
merged.update(class_objects) # Adds parser, etc.
# Filter to only params the function accepts
sig = inspect.signature(func)
if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
# Function accepts **kwargs, pass everything
kwargs = merged
else:
# Only pass params in signature
kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
result = func(**kwargs)
# Reward functions may be async coroutines
if asyncio.iscoroutine(result):
reward = await result
else:
reward = result
reward = float(reward)
except Exception as e:
if self.config.full_debug:
print(f" Reward func {func.__name__} error: {e}")
reward = 0.0
rewards.append(reward)
weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv):
start_time = time.time()
# Create evaluation tasks
tasks = [self.rollout_and_score(item) for item in self.eval_items]
# Run sequentially to avoid signal/threading issues with math_verify parser
# The parser uses signals for timeouts which only work in main thread
from tqdm import tqdm
# Run with progress bar
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating")
results = []
for item in tqdm(self.eval_items, desc="Evaluating"):
result = await self.rollout_and_score(item)
results.append(result)
# Filter out failed results
valid_results = [r for r in results if r is not None]

View file

@ -39,10 +39,12 @@ To install a Verifiers/Prime environment:
Docs: https://docs.primeintellect.ai/tutorials-environments/install
"""
import asyncio
import time
from typing import Any, Dict, List, Optional, Tuple
import verifiers as vf
from openai import AsyncOpenAI
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -82,11 +84,21 @@ class VerifiersEnv(BaseEnv):
self.rubric = self.vf_env.rubric
self.parser = self.rubric.parser
# Handle both single Rubric and RubricGroup (composite)
# RubricGroup has empty funcs/weights at top level - must extract from individual rubrics
if hasattr(self.rubric, "rubrics"):
self.reward_funcs = []
self.reward_weights = []
for rubric in self.rubric.rubrics:
self.reward_funcs.extend(rubric.funcs)
self.reward_weights.extend(rubric.weights)
else:
self.reward_funcs = self.rubric.funcs
self.reward_weights = self.rubric.weights
self.reward_scales = [
weight / sum(self.reward_weights) for weight in self.reward_weights
]
total = sum(self.reward_weights) if self.reward_weights else 1.0
self.reward_scales = [weight / total for weight in self.reward_weights]
self.system_prompt = self.vf_env.system_prompt
@classmethod
@ -135,9 +147,15 @@ class VerifiersEnv(BaseEnv):
async def setup(self):
train_data = self.vf_env.get_dataset()
self.train = train_data.select_columns(["question", "answer"]).to_list()
# Only load columns we need to avoid memory bloat
columns_to_keep = ["question", "answer", "info"]
available_columns = [c for c in columns_to_keep if c in train_data.column_names]
self.train = train_data.select_columns(available_columns).to_list()
test_data = self.vf_env.get_eval_dataset()
self.test = test_data.select_columns(["question", "answer"]).to_list()
available_test_columns = [
c for c in columns_to_keep if c in test_data.column_names
]
self.test = test_data.select_columns(available_test_columns).to_list()
self.iter = 0
def save_checkpoint(self, step, data=None):
@ -254,70 +272,116 @@ class VerifiersEnv(BaseEnv):
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
"""
Collect trajectories - automatically switches between:
- ManagedServer (for RL training with local server - requires logprobs)
- tokenize_for_trainer (for SFT datagen with any API - no logprobs needed)
Collect trajectories - switches between:
- SFT data generation (process mode): Any API, no logprobs needed
- RL training (serve mode): Local server with logprobs
"""
question = item["question"]
answer = item["answer"]
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": question},
]
# Check if we're in process mode (SFT data generation)
is_process_mode = getattr(self, "process_mode", False)
if is_process_mode:
return await self._collect_trajectories_for_sft(messages, answer)
return await self._collect_trajectories_for_sft(item)
else:
return await self._collect_trajectories_for_training(messages, answer)
return await self._collect_trajectories_for_rl(item)
async def _collect_trajectories_for_sft(
self, messages: List[Dict], answer: str
self, item: Dict[str, Any]
) -> Tuple[ScoredDataGroup, list]:
"""
SFT data generation mode - works with ANY API (OpenAI, Claude, local).
Does NOT require logprobs or local server.
Uses tokenize_for_trainer to tokenize completions with your training
tokenizer, so the resulting data is ready for fine-tuning your target model.
Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer
to tokenize completions with your training tokenizer.
"""
completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=1.0,
question = item["question"]
answer = item["answer"]
# Build initial messages
initial_messages: List[Dict[str, str]] = []
if self.system_prompt:
initial_messages.append({"role": "system", "content": self.system_prompt})
initial_messages.append({"role": "user", "content": question})
# Create AsyncOpenAI client directly from server config (no ManagedServer needed)
server_config = self.server.servers[0].config
client = AsyncOpenAI(
api_key=server_config.api_key,
base_url=server_config.base_url,
timeout=server_config.timeout,
)
# Sampling args - use max_completion_tokens for newer models like gpt-5
sampling_args = {
"temperature": 1.0,
"max_completion_tokens": self.config.max_token_length,
}
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
scored_data["masks"] = []
scored_data["scores"] = []
scored_data["messages"] = []
# Note: No inference_logprobs - not needed/available for SFT
for choice in completions.choices:
response = choice.message.content or ""
finish_reason = choice.finish_reason or ""
# Semaphore for scoring (required by rubric.score_rollout)
score_sem = asyncio.Semaphore(1)
# Build full conversation for scoring and tokenization
completion_messages = messages + [
{"role": "assistant", "content": response}
# Run rollouts in parallel for group_size
async def run_single_rollout(example_id: int):
# Pass through any info from the dataset item (e.g., docker_image for SWE envs)
item_info = item.get("info", {})
rollout_input = {
"prompt": initial_messages,
"answer": answer,
"example_id": example_id,
"task": self.config.vf_env_name,
"info": item_info,
}
state = await self.vf_env.rollout(
input=rollout_input,
client=client,
model=server_config.model_name,
sampling_args=sampling_args,
)
# Score the rollout using verifiers rubric (computes reward from test output)
# This is needed because vf_env.rollout() doesn't call score_rollout
await self.rubric.score_rollout(state, score_sem=score_sem)
return state
# Run group_size rollouts in parallel
rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)]
states = await asyncio.gather(*rollout_tasks)
for state in states:
# Extract completion messages from state
completion_messages = list(state.get("prompt", [])) + list(
state.get("completion", [])
)
# Ensure all message contents are strings (not None)
# This can happen with tool call messages that have content: null
completion_messages = [
{**msg, "content": msg.get("content") or ""}
for msg in completion_messages
]
# Score using verifiers reward funcs
score = self._compute_score(completion_messages, answer)
# Get reward from verifiers scoring (set by rubric.score_rollout above)
score = state.get("reward", 0.0)
# Determine finish reason from last trajectory step
trajectory = state.get("trajectory", [])
if trajectory:
finish_reason = trajectory[-1]["response"].choices[0].finish_reason
else:
finish_reason = "stop"
# Use tokenize_for_trainer for tokenization
# This uses YOUR training tokenizer (e.g., Qwen, Llama), not the API's tokenizer
# So GPT-4o responses get tokenized for your target model
# train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked
# for multi-turn environments, not just the last message
tokenized = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=completion_messages,
include_messages=True,
finish_reason=finish_reason,
train_on_all_assistant_turns=True,
)
scored_data["tokens"].append(tokenized["tokens"])
@ -331,48 +395,72 @@ class VerifiersEnv(BaseEnv):
return scored_data, []
async def _collect_trajectories_for_training(
self, messages: List[Dict], answer: str
async def _collect_trajectories_for_rl(
self, item: Dict[str, Any]
) -> Tuple[ScoredDataGroup, list]:
"""
RL training mode - requires local inference server.
Uses ManagedServer for proper token/logprob alignment.
The inference_logprobs are required for policy gradient methods like
GRPO, PPO, REINFORCE, etc.
RL training mode - requires local inference server for logprobs.
Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn.
"""
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=1.0,
from atroposlib.envs.server_handling.atropos_managed_client import (
AtroposManagedClient,
)
state = managed.get_state()
nodes = state["nodes"]
question = item["question"]
answer = item["answer"]
item_info = item.get("info", {})
initial_messages: List[Dict[str, str]] = []
if self.system_prompt:
initial_messages.append({"role": "system", "content": self.system_prompt})
initial_messages.append({"role": "user", "content": question})
sampling_args = {
"temperature": 1.0,
"max_completion_tokens": self.config.max_token_length,
}
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
scored_data["masks"] = []
scored_data["scores"] = []
scored_data["inference_logprobs"] = [] # Required for RL training!
scored_data["inference_logprobs"] = []
for i, choice in enumerate(completions.choices):
response = choice.message.content or ""
# Semaphore for scoring (required by rubric.score_rollout)
score_sem = asyncio.Semaphore(1)
# Build full conversation for scoring
completion_messages = messages + [
{"role": "assistant", "content": response}
]
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
client = AtroposManagedClient(
managed_server=managed,
model=self.server_configs[0].model_name,
)
# Score using verifiers reward funcs
score = self._compute_score(completion_messages, answer)
# Run group_size rollouts sequentially (ManagedServer state must be reset between)
for i in range(self.config.group_size):
client.reset()
# Use ManagedServer's properly aligned tokens/masks/logprobs
node = nodes[i]
scored_data["tokens"].append(node.tokens)
scored_data["masks"].append(node.masked_tokens)
scored_data["inference_logprobs"].append(node.logprobs)
rollout_input = {
"prompt": initial_messages,
"answer": answer,
"example_id": i,
"task": self.config.vf_env_name,
"info": item_info,
}
state = await self.vf_env.rollout(
input=rollout_input,
client=client,
model=self.server_configs[0].model_name,
sampling_args=sampling_args,
)
# Score the rollout (computes reward from test output)
await self.rubric.score_rollout(state, score_sem=score_sem)
tokens, masks, logprobs, score = self._extract_from_state(state)
scored_data["tokens"].append(tokens)
scored_data["masks"].append(masks)
scored_data["inference_logprobs"].append(logprobs)
scored_data["scores"].append(score)
# Track scores for wandb logging
@ -381,6 +469,42 @@ class VerifiersEnv(BaseEnv):
return scored_data, []
def _extract_from_state(
self, state: Any
) -> Tuple[List[int], List[int], List[float], float]:
"""
Extract tokens/masks/logprobs/score from a single rollout state.
Handles the mask convention conversion:
- Verifiers: prompt_mask=0, completion_mask=1
- Atropos: masked_tokens=-100 (prompt), token_id (completion)
"""
all_tokens: List[int] = []
all_masks: List[int] = []
all_logprobs: List[float] = []
trajectory = state.get("trajectory", [])
for step in trajectory:
tokens = step["tokens"]
prompt_ids = tokens["prompt_ids"]
completion_ids = tokens["completion_ids"]
completion_logprobs = tokens["completion_logprobs"]
all_tokens.extend(prompt_ids)
all_tokens.extend(completion_ids)
all_masks.extend([-100] * len(prompt_ids))
all_masks.extend(completion_ids)
all_logprobs.extend([1.0] * len(prompt_ids))
all_logprobs.extend(completion_logprobs)
reward = state["reward"]
return all_tokens, all_masks, all_logprobs, reward
if __name__ == "__main__":
VerifiersEnv.cli()