mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
rework server and eval for rl rollout. add in asyncmanagedserver for
verifiers
This commit is contained in:
parent
3449a4c23d
commit
cf636595d2
4 changed files with 652 additions and 91 deletions
330
atroposlib/envs/server_handling/atropos_managed_client.py
Normal file
330
atroposlib/envs/server_handling/atropos_managed_client.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
"""
|
||||
AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer.
|
||||
|
||||
This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's
|
||||
ManagedServer for inference, enabling token tracking for multi-turn RL training
|
||||
with the Verifiers library.
|
||||
|
||||
Usage:
|
||||
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
|
||||
client = AtroposManagedClient(managed_server=managed, model="model-name")
|
||||
|
||||
# Use like AsyncOpenAI - tokens are tracked automatically
|
||||
response = await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# Token data is available on the response:
|
||||
# - response.prompt_token_ids
|
||||
# - response.choices[0].token_ids
|
||||
# - response.choices[0].logprobs.content[i].logprob
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode
|
||||
|
||||
# =============================================================================
|
||||
# Enhanced Types for Token Data Injection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogprobContent:
|
||||
"""
|
||||
Single token logprob entry.
|
||||
|
||||
Compatible with verifiers' parse_response_tokens() which accesses:
|
||||
- response.choices[i].logprobs.content[j].logprob
|
||||
"""
|
||||
|
||||
logprob: float
|
||||
token: str = ""
|
||||
token_id: int = 0
|
||||
top_logprobs: Optional[List[Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChoiceLogprobs:
|
||||
"""
|
||||
Logprobs structure compatible with verifiers expectations.
|
||||
|
||||
Verifiers checks for either object or dict format:
|
||||
- Object: response.choices[i].logprobs.content[j].logprob
|
||||
- Dict: response.choices[i].logprobs["content"][j]["logprob"]
|
||||
|
||||
This dataclass supports the object format.
|
||||
"""
|
||||
|
||||
content: List[LogprobContent] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnhancedChoice:
|
||||
"""
|
||||
Choice with token_ids and logprobs for RL training.
|
||||
|
||||
Adds the following attributes that verifiers expects:
|
||||
- token_ids: List[int] - completion token IDs
|
||||
- logprobs: ChoiceLogprobs - structured logprobs
|
||||
"""
|
||||
|
||||
index: int
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: str
|
||||
token_ids: List[int]
|
||||
logprobs: ChoiceLogprobs
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnhancedChatCompletion:
|
||||
"""
|
||||
ChatCompletion with token data for RL training.
|
||||
|
||||
Compatible with verifiers' parse_response_tokens() expectations:
|
||||
- prompt_token_ids: list[int]
|
||||
- choices[i].token_ids: list[int]
|
||||
- choices[i].logprobs.content[j].logprob
|
||||
"""
|
||||
|
||||
id: str
|
||||
created: int
|
||||
model: str
|
||||
object: str
|
||||
choices: List[EnhancedChoice]
|
||||
prompt_token_ids: List[int]
|
||||
usage: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AsyncOpenAI-Compatible Client Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class _CompletionsNamespace:
|
||||
"""
|
||||
Mimics openai.resources.chat.completions.AsyncCompletions.
|
||||
|
||||
Provides the create() method that verifiers calls.
|
||||
"""
|
||||
|
||||
def __init__(self, parent: "AtroposManagedClient"):
|
||||
self.parent = parent
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: Optional[str] = None,
|
||||
n: int = 1,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> EnhancedChatCompletion:
|
||||
"""
|
||||
Create chat completion with token tracking.
|
||||
|
||||
Returns ChatCompletion with additional attributes:
|
||||
- prompt_token_ids: list[int]
|
||||
- choices[i].token_ids: list[int]
|
||||
- choices[i].logprobs.content: list with logprob info
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
model: Model name (defaults to client's model)
|
||||
n: Number of completions (should be 1 for multi-turn)
|
||||
max_tokens: Max tokens in completion (legacy param)
|
||||
max_completion_tokens: Max tokens in completion (new param)
|
||||
temperature: Sampling temperature
|
||||
top_p: Nucleus sampling parameter
|
||||
tools: Tool definitions for function calling
|
||||
stop: Stop sequences
|
||||
**kwargs: Additional parameters passed to ManagedServer
|
||||
"""
|
||||
# Use max_completion_tokens if provided, else max_tokens
|
||||
effective_max_tokens = max_completion_tokens or max_tokens
|
||||
|
||||
# Build kwargs for ManagedServer
|
||||
completion_kwargs = {
|
||||
"messages": messages,
|
||||
"model": model or self.parent.model,
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
if effective_max_tokens is not None:
|
||||
completion_kwargs["max_tokens"] = effective_max_tokens
|
||||
|
||||
if tools is not None:
|
||||
completion_kwargs["tools"] = tools
|
||||
|
||||
if stop is not None:
|
||||
completion_kwargs["stop"] = stop
|
||||
|
||||
# Add any extra kwargs (like logprobs settings)
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
completion_kwargs[key] = value
|
||||
|
||||
# Call ManagedServer for inference
|
||||
completion = await self.parent.managed_server.chat_completion(
|
||||
**completion_kwargs
|
||||
)
|
||||
|
||||
# Get token state from managed server
|
||||
state = self.parent.managed_server.get_state()
|
||||
nodes: List[SequenceNode] = state["nodes"]
|
||||
|
||||
# Inject token data into response
|
||||
return self._enhance_completion(completion, nodes)
|
||||
|
||||
def _enhance_completion(
|
||||
self, completion: Any, nodes: List[SequenceNode]
|
||||
) -> EnhancedChatCompletion:
|
||||
"""
|
||||
Convert ManagedServer output to verifiers-compatible format.
|
||||
|
||||
Extracts token data from SequenceNodes and injects it into the
|
||||
ChatCompletion response in the format verifiers expects.
|
||||
"""
|
||||
enhanced_choices = []
|
||||
prompt_token_ids: List[int] = []
|
||||
|
||||
for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
|
||||
# Find prompt/completion boundary from masked_tokens
|
||||
# -100 indicates prompt tokens, actual token IDs indicate completion
|
||||
prompt_len = sum(1 for m in node.masked_tokens if m == -100)
|
||||
|
||||
# Extract prompt and completion portions
|
||||
if i == 0:
|
||||
prompt_token_ids = node.tokens[:prompt_len]
|
||||
|
||||
completion_ids = node.tokens[prompt_len:]
|
||||
completion_logprobs = node.logprobs[prompt_len:]
|
||||
|
||||
# Build logprobs structure verifiers expects
|
||||
logprobs_content = []
|
||||
tokenizer = self.parent.managed_server.tokenizer
|
||||
|
||||
for token_id, logprob in zip(completion_ids, completion_logprobs):
|
||||
# Decode token to string if tokenizer available
|
||||
token_str = ""
|
||||
if tokenizer is not None:
|
||||
try:
|
||||
token_str = tokenizer.decode([token_id])
|
||||
except Exception:
|
||||
token_str = f"<token_{token_id}>"
|
||||
|
||||
logprobs_content.append(
|
||||
LogprobContent(
|
||||
logprob=logprob,
|
||||
token=token_str,
|
||||
token_id=token_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Create enhanced choice with token data
|
||||
enhanced_choice = EnhancedChoice(
|
||||
index=choice.index,
|
||||
message=choice.message,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
token_ids=completion_ids,
|
||||
logprobs=ChoiceLogprobs(content=logprobs_content),
|
||||
)
|
||||
enhanced_choices.append(enhanced_choice)
|
||||
|
||||
return EnhancedChatCompletion(
|
||||
id=completion.id,
|
||||
created=completion.created,
|
||||
model=completion.model,
|
||||
object=completion.object,
|
||||
choices=enhanced_choices,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
usage=completion.usage.model_dump() if completion.usage else None,
|
||||
)
|
||||
|
||||
|
||||
class _ChatNamespace:
|
||||
"""Mimics openai.resources.chat.AsyncChat."""
|
||||
|
||||
def __init__(self, parent: "AtroposManagedClient"):
|
||||
self.completions = _CompletionsNamespace(parent)
|
||||
|
||||
|
||||
class AtroposManagedClient:
|
||||
"""
|
||||
AsyncOpenAI-compatible client backed by ManagedServer.
|
||||
|
||||
This client provides the same interface as AsyncOpenAI but uses Atropos's
|
||||
ManagedServer for inference, enabling automatic token tracking for
|
||||
multi-turn RL training with the Verifiers library.
|
||||
|
||||
The key feature is that responses include token data attributes that
|
||||
verifiers' parse_response_tokens() expects:
|
||||
- response.prompt_token_ids
|
||||
- response.choices[i].token_ids
|
||||
- response.choices[i].logprobs.content[j].logprob
|
||||
|
||||
Usage:
|
||||
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
|
||||
client = AtroposManagedClient(
|
||||
managed_server=managed,
|
||||
model="Qwen/Qwen2.5-1.5B-Instruct"
|
||||
)
|
||||
|
||||
# Pass to verifiers env.rollout()
|
||||
state = await vf_env.rollout(
|
||||
input=rollout_input,
|
||||
client=client,
|
||||
model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
)
|
||||
|
||||
# Token data is now in state["trajectory"][i]["tokens"]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
managed_server: ManagedServer,
|
||||
model: str,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the managed client.
|
||||
|
||||
Args:
|
||||
managed_server: ManagedServer instance for inference and token tracking
|
||||
model: Model name to use for completions
|
||||
base_url: Optional base URL (for API compatibility, not used)
|
||||
"""
|
||||
self.managed_server = managed_server
|
||||
self.model = model
|
||||
self.base_url = base_url or "http://managed-server"
|
||||
|
||||
# Mimic AsyncOpenAI namespace structure
|
||||
self.chat = _ChatNamespace(self)
|
||||
|
||||
def reset(self):
|
||||
"""Reset token tracking state between rollouts."""
|
||||
self.managed_server.reset()
|
||||
|
||||
async def close(self):
|
||||
"""Compatibility method - no-op since ManagedServer handles cleanup."""
|
||||
pass
|
||||
|
||||
def copy(self, **_kwargs) -> "AtroposManagedClient":
|
||||
"""
|
||||
Create a copy of this client (for API compatibility).
|
||||
|
||||
Verifiers may call client.copy() for certain operations.
|
||||
Returns self since we want to maintain the same ManagedServer state.
|
||||
"""
|
||||
return self
|
||||
31
environments/configs/verifiers.yaml
Normal file
31
environments/configs/verifiers.yaml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Verifiers environment configuration
|
||||
# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml
|
||||
#
|
||||
# For SFT data generation with external API:
|
||||
# python environments/verifiers_server.py process \
|
||||
# --env.vf_env_name primeintellect/gsm8k \
|
||||
# --env.data_path_to_save_groups output.jsonl \
|
||||
# --openai.base_url https://api.openai.com/v1 \
|
||||
# --openai.api_key $OPENAI_API_KEY \
|
||||
# --openai.model_name gpt-4o
|
||||
|
||||
env:
|
||||
vf_env_name: "primeintellect/gsm8k" # Prime Env Hub environment
|
||||
env_args: {}
|
||||
group_size: 8
|
||||
max_token_length: 2048
|
||||
tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
rollout_server_url: "http://localhost:8000"
|
||||
use_wandb: true
|
||||
wandb_name: "verifiers"
|
||||
total_steps: 1000
|
||||
batch_size: 4
|
||||
steps_per_eval: 100
|
||||
|
||||
openai:
|
||||
- model_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
base_url: "http://localhost:9001/v1"
|
||||
api_key: "x"
|
||||
|
||||
slurm: false
|
||||
testing: false
|
||||
|
|
@ -18,15 +18,15 @@ Usage:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import verifiers as vf
|
||||
import wandb
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
|
|
@ -34,6 +34,36 @@ from atroposlib.envs.base import (
|
|||
)
|
||||
|
||||
|
||||
# Patch math_verify timeout to work in async context
|
||||
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
|
||||
def _no_signal_timeout(timeout_seconds: int):
|
||||
"""Replacement timeout decorator that doesn't use signals."""
|
||||
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Just call the function without timeout
|
||||
# This is safe because we're in an async context with our own timeouts
|
||||
# timeout_seconds is intentionally unused - we're replacing the timeout logic
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
try:
|
||||
import math_verify.grader
|
||||
import math_verify.parser
|
||||
import math_verify.utils
|
||||
|
||||
# Patch all modules that use the timeout decorator
|
||||
math_verify.utils.timeout = _no_signal_timeout
|
||||
math_verify.parser.timeout = _no_signal_timeout
|
||||
math_verify.grader.timeout = _no_signal_timeout
|
||||
except ImportError:
|
||||
pass # math_verify not installed
|
||||
|
||||
|
||||
class VerifiersEvaluationConfig(BaseEnvConfig):
|
||||
"""Configuration for Verifiers evaluation environment."""
|
||||
|
||||
|
|
@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
|
||||
self.rubric = self.vf_env.rubric
|
||||
|
||||
# Extract rubric components
|
||||
# Extract rubric components from RubricGroup
|
||||
# RubricGroup.funcs is empty - need to collect from individual rubrics
|
||||
self.parser = self.rubric.parser
|
||||
self.reward_funcs: List[Callable] = []
|
||||
self.reward_weights: List[float] = []
|
||||
self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func
|
||||
|
||||
if hasattr(self.rubric, "rubrics"):
|
||||
# RubricGroup: collect from all individual rubrics
|
||||
for rubric in self.rubric.rubrics:
|
||||
class_objects = getattr(rubric, "class_objects", {})
|
||||
for func, weight in zip(rubric.funcs, rubric.weights):
|
||||
self.reward_funcs.append(func)
|
||||
self.reward_weights.append(weight)
|
||||
self.rubric_class_objects.append(class_objects)
|
||||
else:
|
||||
# Single Rubric
|
||||
self.reward_funcs = self.rubric.funcs
|
||||
self.reward_weights = self.rubric.weights
|
||||
self.reward_scales = [
|
||||
weight / sum(self.reward_weights) for weight in self.reward_weights
|
||||
]
|
||||
class_objects = getattr(self.rubric, "class_objects", {})
|
||||
self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)
|
||||
|
||||
total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
|
||||
self.reward_scales = [weight / total_weight for weight in self.reward_weights]
|
||||
self.system_prompt = self.vf_env.system_prompt
|
||||
|
||||
# Tracking
|
||||
|
|
@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
# Parse answer
|
||||
answer_parsed = self.parser.parse_answer(completion=response_text)
|
||||
|
||||
# Score using reward funcs
|
||||
# Score using reward funcs (async functions need await)
|
||||
# Use signature introspection to pass only required params (like verifiers does)
|
||||
rewards = []
|
||||
for func in self.reward_funcs:
|
||||
reward = func(
|
||||
parser=self.parser,
|
||||
completion=completion_messages,
|
||||
answer=answer,
|
||||
)
|
||||
for i, func in enumerate(self.reward_funcs):
|
||||
try:
|
||||
# Build merged dict of all possible parameters
|
||||
class_objects = self.rubric_class_objects[i]
|
||||
merged = {
|
||||
"completion": completion_messages,
|
||||
"answer": answer,
|
||||
"prompt": question,
|
||||
}
|
||||
merged.update(class_objects) # Adds parser, etc.
|
||||
|
||||
# Filter to only params the function accepts
|
||||
sig = inspect.signature(func)
|
||||
if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
|
||||
# Function accepts **kwargs, pass everything
|
||||
kwargs = merged
|
||||
else:
|
||||
# Only pass params in signature
|
||||
kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
|
||||
|
||||
result = func(**kwargs)
|
||||
# Reward functions may be async coroutines
|
||||
if asyncio.iscoroutine(result):
|
||||
reward = await result
|
||||
else:
|
||||
reward = result
|
||||
reward = float(reward)
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" Reward func {func.__name__} error: {e}")
|
||||
reward = 0.0
|
||||
rewards.append(reward)
|
||||
|
||||
weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
|
||||
|
|
@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv):
|
|||
|
||||
start_time = time.time()
|
||||
|
||||
# Create evaluation tasks
|
||||
tasks = [self.rollout_and_score(item) for item in self.eval_items]
|
||||
# Run sequentially to avoid signal/threading issues with math_verify parser
|
||||
# The parser uses signals for timeouts which only work in main thread
|
||||
from tqdm import tqdm
|
||||
|
||||
# Run with progress bar
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating")
|
||||
results = []
|
||||
for item in tqdm(self.eval_items, desc="Evaluating"):
|
||||
result = await self.rollout_and_score(item)
|
||||
results.append(result)
|
||||
|
||||
# Filter out failed results
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
|
|
|||
|
|
@ -39,10 +39,12 @@ To install a Verifiers/Prime environment:
|
|||
Docs: https://docs.primeintellect.ai/tutorials-environments/install
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import verifiers as vf
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -82,11 +84,21 @@ class VerifiersEnv(BaseEnv):
|
|||
self.rubric = self.vf_env.rubric
|
||||
|
||||
self.parser = self.rubric.parser
|
||||
|
||||
# Handle both single Rubric and RubricGroup (composite)
|
||||
# RubricGroup has empty funcs/weights at top level - must extract from individual rubrics
|
||||
if hasattr(self.rubric, "rubrics"):
|
||||
self.reward_funcs = []
|
||||
self.reward_weights = []
|
||||
for rubric in self.rubric.rubrics:
|
||||
self.reward_funcs.extend(rubric.funcs)
|
||||
self.reward_weights.extend(rubric.weights)
|
||||
else:
|
||||
self.reward_funcs = self.rubric.funcs
|
||||
self.reward_weights = self.rubric.weights
|
||||
self.reward_scales = [
|
||||
weight / sum(self.reward_weights) for weight in self.reward_weights
|
||||
]
|
||||
|
||||
total = sum(self.reward_weights) if self.reward_weights else 1.0
|
||||
self.reward_scales = [weight / total for weight in self.reward_weights]
|
||||
self.system_prompt = self.vf_env.system_prompt
|
||||
|
||||
@classmethod
|
||||
|
|
@ -135,9 +147,15 @@ class VerifiersEnv(BaseEnv):
|
|||
|
||||
async def setup(self):
|
||||
train_data = self.vf_env.get_dataset()
|
||||
self.train = train_data.select_columns(["question", "answer"]).to_list()
|
||||
# Only load columns we need to avoid memory bloat
|
||||
columns_to_keep = ["question", "answer", "info"]
|
||||
available_columns = [c for c in columns_to_keep if c in train_data.column_names]
|
||||
self.train = train_data.select_columns(available_columns).to_list()
|
||||
test_data = self.vf_env.get_eval_dataset()
|
||||
self.test = test_data.select_columns(["question", "answer"]).to_list()
|
||||
available_test_columns = [
|
||||
c for c in columns_to_keep if c in test_data.column_names
|
||||
]
|
||||
self.test = test_data.select_columns(available_test_columns).to_list()
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
|
|
@ -254,70 +272,116 @@ class VerifiersEnv(BaseEnv):
|
|||
|
||||
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
|
||||
"""
|
||||
Collect trajectories - automatically switches between:
|
||||
- ManagedServer (for RL training with local server - requires logprobs)
|
||||
- tokenize_for_trainer (for SFT datagen with any API - no logprobs needed)
|
||||
Collect trajectories - switches between:
|
||||
- SFT data generation (process mode): Any API, no logprobs needed
|
||||
- RL training (serve mode): Local server with logprobs
|
||||
"""
|
||||
question = item["question"]
|
||||
answer = item["answer"]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
|
||||
# Check if we're in process mode (SFT data generation)
|
||||
is_process_mode = getattr(self, "process_mode", False)
|
||||
|
||||
if is_process_mode:
|
||||
return await self._collect_trajectories_for_sft(messages, answer)
|
||||
return await self._collect_trajectories_for_sft(item)
|
||||
else:
|
||||
return await self._collect_trajectories_for_training(messages, answer)
|
||||
return await self._collect_trajectories_for_rl(item)
|
||||
|
||||
async def _collect_trajectories_for_sft(
|
||||
self, messages: List[Dict], answer: str
|
||||
self, item: Dict[str, Any]
|
||||
) -> Tuple[ScoredDataGroup, list]:
|
||||
"""
|
||||
SFT data generation mode - works with ANY API (OpenAI, Claude, local).
|
||||
Does NOT require logprobs or local server.
|
||||
|
||||
Uses tokenize_for_trainer to tokenize completions with your training
|
||||
tokenizer, so the resulting data is ready for fine-tuning your target model.
|
||||
Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer
|
||||
to tokenize completions with your training tokenizer.
|
||||
"""
|
||||
completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=1.0,
|
||||
question = item["question"]
|
||||
answer = item["answer"]
|
||||
|
||||
# Build initial messages
|
||||
initial_messages: List[Dict[str, str]] = []
|
||||
if self.system_prompt:
|
||||
initial_messages.append({"role": "system", "content": self.system_prompt})
|
||||
initial_messages.append({"role": "user", "content": question})
|
||||
|
||||
# Create AsyncOpenAI client directly from server config (no ManagedServer needed)
|
||||
server_config = self.server.servers[0].config
|
||||
client = AsyncOpenAI(
|
||||
api_key=server_config.api_key,
|
||||
base_url=server_config.base_url,
|
||||
timeout=server_config.timeout,
|
||||
)
|
||||
|
||||
# Sampling args - use max_completion_tokens for newer models like gpt-5
|
||||
sampling_args = {
|
||||
"temperature": 1.0,
|
||||
"max_completion_tokens": self.config.max_token_length,
|
||||
}
|
||||
|
||||
scored_data = ScoredDataGroup()
|
||||
scored_data["tokens"] = []
|
||||
scored_data["masks"] = []
|
||||
scored_data["scores"] = []
|
||||
scored_data["messages"] = []
|
||||
# Note: No inference_logprobs - not needed/available for SFT
|
||||
|
||||
for choice in completions.choices:
|
||||
response = choice.message.content or ""
|
||||
finish_reason = choice.finish_reason or ""
|
||||
# Semaphore for scoring (required by rubric.score_rollout)
|
||||
score_sem = asyncio.Semaphore(1)
|
||||
|
||||
# Build full conversation for scoring and tokenization
|
||||
completion_messages = messages + [
|
||||
{"role": "assistant", "content": response}
|
||||
# Run rollouts in parallel for group_size
|
||||
async def run_single_rollout(example_id: int):
|
||||
# Pass through any info from the dataset item (e.g., docker_image for SWE envs)
|
||||
item_info = item.get("info", {})
|
||||
rollout_input = {
|
||||
"prompt": initial_messages,
|
||||
"answer": answer,
|
||||
"example_id": example_id,
|
||||
"task": self.config.vf_env_name,
|
||||
"info": item_info,
|
||||
}
|
||||
state = await self.vf_env.rollout(
|
||||
input=rollout_input,
|
||||
client=client,
|
||||
model=server_config.model_name,
|
||||
sampling_args=sampling_args,
|
||||
)
|
||||
# Score the rollout using verifiers rubric (computes reward from test output)
|
||||
# This is needed because vf_env.rollout() doesn't call score_rollout
|
||||
await self.rubric.score_rollout(state, score_sem=score_sem)
|
||||
return state
|
||||
|
||||
# Run group_size rollouts in parallel
|
||||
rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)]
|
||||
states = await asyncio.gather(*rollout_tasks)
|
||||
|
||||
for state in states:
|
||||
# Extract completion messages from state
|
||||
completion_messages = list(state.get("prompt", [])) + list(
|
||||
state.get("completion", [])
|
||||
)
|
||||
# Ensure all message contents are strings (not None)
|
||||
# This can happen with tool call messages that have content: null
|
||||
completion_messages = [
|
||||
{**msg, "content": msg.get("content") or ""}
|
||||
for msg in completion_messages
|
||||
]
|
||||
|
||||
# Score using verifiers reward funcs
|
||||
score = self._compute_score(completion_messages, answer)
|
||||
# Get reward from verifiers scoring (set by rubric.score_rollout above)
|
||||
score = state.get("reward", 0.0)
|
||||
|
||||
# Determine finish reason from last trajectory step
|
||||
trajectory = state.get("trajectory", [])
|
||||
if trajectory:
|
||||
finish_reason = trajectory[-1]["response"].choices[0].finish_reason
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# Use tokenize_for_trainer for tokenization
|
||||
# This uses YOUR training tokenizer (e.g., Qwen, Llama), not the API's tokenizer
|
||||
# So GPT-4o responses get tokenized for your target model
|
||||
# train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked
|
||||
# for multi-turn environments, not just the last message
|
||||
tokenized = tokenize_for_trainer(
|
||||
tokenizer=self.tokenizer,
|
||||
chat=completion_messages,
|
||||
include_messages=True,
|
||||
finish_reason=finish_reason,
|
||||
train_on_all_assistant_turns=True,
|
||||
)
|
||||
|
||||
scored_data["tokens"].append(tokenized["tokens"])
|
||||
|
|
@ -331,48 +395,72 @@ class VerifiersEnv(BaseEnv):
|
|||
|
||||
return scored_data, []
|
||||
|
||||
async def _collect_trajectories_for_training(
|
||||
self, messages: List[Dict], answer: str
|
||||
async def _collect_trajectories_for_rl(
|
||||
self, item: Dict[str, Any]
|
||||
) -> Tuple[ScoredDataGroup, list]:
|
||||
"""
|
||||
RL training mode - requires local inference server.
|
||||
Uses ManagedServer for proper token/logprob alignment.
|
||||
|
||||
The inference_logprobs are required for policy gradient methods like
|
||||
GRPO, PPO, REINFORCE, etc.
|
||||
RL training mode - requires local inference server for logprobs.
|
||||
Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn.
|
||||
"""
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completions = await managed.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=1.0,
|
||||
from atroposlib.envs.server_handling.atropos_managed_client import (
|
||||
AtroposManagedClient,
|
||||
)
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
question = item["question"]
|
||||
answer = item["answer"]
|
||||
item_info = item.get("info", {})
|
||||
|
||||
initial_messages: List[Dict[str, str]] = []
|
||||
if self.system_prompt:
|
||||
initial_messages.append({"role": "system", "content": self.system_prompt})
|
||||
initial_messages.append({"role": "user", "content": question})
|
||||
|
||||
sampling_args = {
|
||||
"temperature": 1.0,
|
||||
"max_completion_tokens": self.config.max_token_length,
|
||||
}
|
||||
|
||||
scored_data = ScoredDataGroup()
|
||||
scored_data["tokens"] = []
|
||||
scored_data["masks"] = []
|
||||
scored_data["scores"] = []
|
||||
scored_data["inference_logprobs"] = [] # Required for RL training!
|
||||
scored_data["inference_logprobs"] = []
|
||||
|
||||
for i, choice in enumerate(completions.choices):
|
||||
response = choice.message.content or ""
|
||||
# Semaphore for scoring (required by rubric.score_rollout)
|
||||
score_sem = asyncio.Semaphore(1)
|
||||
|
||||
# Build full conversation for scoring
|
||||
completion_messages = messages + [
|
||||
{"role": "assistant", "content": response}
|
||||
]
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
client = AtroposManagedClient(
|
||||
managed_server=managed,
|
||||
model=self.server_configs[0].model_name,
|
||||
)
|
||||
|
||||
# Score using verifiers reward funcs
|
||||
score = self._compute_score(completion_messages, answer)
|
||||
# Run group_size rollouts sequentially (ManagedServer state must be reset between)
|
||||
for i in range(self.config.group_size):
|
||||
client.reset()
|
||||
|
||||
# Use ManagedServer's properly aligned tokens/masks/logprobs
|
||||
node = nodes[i]
|
||||
scored_data["tokens"].append(node.tokens)
|
||||
scored_data["masks"].append(node.masked_tokens)
|
||||
scored_data["inference_logprobs"].append(node.logprobs)
|
||||
rollout_input = {
|
||||
"prompt": initial_messages,
|
||||
"answer": answer,
|
||||
"example_id": i,
|
||||
"task": self.config.vf_env_name,
|
||||
"info": item_info,
|
||||
}
|
||||
|
||||
state = await self.vf_env.rollout(
|
||||
input=rollout_input,
|
||||
client=client,
|
||||
model=self.server_configs[0].model_name,
|
||||
sampling_args=sampling_args,
|
||||
)
|
||||
|
||||
# Score the rollout (computes reward from test output)
|
||||
await self.rubric.score_rollout(state, score_sem=score_sem)
|
||||
|
||||
tokens, masks, logprobs, score = self._extract_from_state(state)
|
||||
scored_data["tokens"].append(tokens)
|
||||
scored_data["masks"].append(masks)
|
||||
scored_data["inference_logprobs"].append(logprobs)
|
||||
scored_data["scores"].append(score)
|
||||
|
||||
# Track scores for wandb logging
|
||||
|
|
@ -381,6 +469,42 @@ class VerifiersEnv(BaseEnv):
|
|||
|
||||
return scored_data, []
|
||||
|
||||
def _extract_from_state(
|
||||
self, state: Any
|
||||
) -> Tuple[List[int], List[int], List[float], float]:
|
||||
"""
|
||||
Extract tokens/masks/logprobs/score from a single rollout state.
|
||||
|
||||
Handles the mask convention conversion:
|
||||
- Verifiers: prompt_mask=0, completion_mask=1
|
||||
- Atropos: masked_tokens=-100 (prompt), token_id (completion)
|
||||
"""
|
||||
all_tokens: List[int] = []
|
||||
all_masks: List[int] = []
|
||||
all_logprobs: List[float] = []
|
||||
|
||||
trajectory = state.get("trajectory", [])
|
||||
|
||||
for step in trajectory:
|
||||
tokens = step["tokens"]
|
||||
|
||||
prompt_ids = tokens["prompt_ids"]
|
||||
completion_ids = tokens["completion_ids"]
|
||||
completion_logprobs = tokens["completion_logprobs"]
|
||||
|
||||
all_tokens.extend(prompt_ids)
|
||||
all_tokens.extend(completion_ids)
|
||||
|
||||
all_masks.extend([-100] * len(prompt_ids))
|
||||
all_masks.extend(completion_ids)
|
||||
|
||||
all_logprobs.extend([1.0] * len(prompt_ids))
|
||||
all_logprobs.extend(completion_logprobs)
|
||||
|
||||
reward = state["reward"]
|
||||
|
||||
return all_tokens, all_masks, all_logprobs, reward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
VerifiersEnv.cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue