From cf636595d22696cc55537d99a623874924131bd0 Mon Sep 17 00:00:00 2001 From: "balyan.sid@gmail.com" Date: Sat, 10 Jan 2026 14:55:08 +0530 Subject: [PATCH] rework server and eval for rl rollout. add in asyncmanagedserver for verifiers --- .../server_handling/atropos_managed_client.py | 330 ++++++++++++++++++ environments/configs/verifiers.yaml | 31 ++ .../eval_environments/verifiers_eval.py | 116 ++++-- environments/verifiers_server.py | 266 ++++++++++---- 4 files changed, 652 insertions(+), 91 deletions(-) create mode 100644 atroposlib/envs/server_handling/atropos_managed_client.py create mode 100644 environments/configs/verifiers.yaml diff --git a/atroposlib/envs/server_handling/atropos_managed_client.py b/atroposlib/envs/server_handling/atropos_managed_client.py new file mode 100644 index 00000000..75e78e39 --- /dev/null +++ b/atroposlib/envs/server_handling/atropos_managed_client.py @@ -0,0 +1,330 @@ +""" +AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer. + +This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's +ManagedServer for inference, enabling token tracking for multi-turn RL training +with the Verifiers library. 
+ +Usage: + async with server_manager.managed_server(tokenizer=tokenizer) as managed: + client = AtroposManagedClient(managed_server=managed, model="model-name") + + # Use like AsyncOpenAI - tokens are tracked automatically + response = await client.chat.completions.create( + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100 + ) + + # Token data is available on the response: + # - response.prompt_token_ids + # - response.choices[0].token_ids + # - response.choices[0].logprobs.content[i].logprob +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode + +# ============================================================================= +# Enhanced Types for Token Data Injection +# ============================================================================= + + +@dataclass +class LogprobContent: + """ + Single token logprob entry. + + Compatible with verifiers' parse_response_tokens() which accesses: + - response.choices[i].logprobs.content[j].logprob + """ + + logprob: float + token: str = "" + token_id: int = 0 + top_logprobs: Optional[List[Any]] = None + + +@dataclass +class ChoiceLogprobs: + """ + Logprobs structure compatible with verifiers expectations. + + Verifiers checks for either object or dict format: + - Object: response.choices[i].logprobs.content[j].logprob + - Dict: response.choices[i].logprobs["content"][j]["logprob"] + + This dataclass supports the object format. + """ + + content: List[LogprobContent] = field(default_factory=list) + + +@dataclass +class EnhancedChoice: + """ + Choice with token_ids and logprobs for RL training. 
+ + Adds the following attributes that verifiers expects: + - token_ids: List[int] - completion token IDs + - logprobs: ChoiceLogprobs - structured logprobs + """ + + index: int + message: ChatCompletionMessage + finish_reason: str + token_ids: List[int] + logprobs: ChoiceLogprobs + + +@dataclass +class EnhancedChatCompletion: + """ + ChatCompletion with token data for RL training. + + Compatible with verifiers' parse_response_tokens() expectations: + - prompt_token_ids: list[int] + - choices[i].token_ids: list[int] + - choices[i].logprobs.content[j].logprob + """ + + id: str + created: int + model: str + object: str + choices: List[EnhancedChoice] + prompt_token_ids: List[int] + usage: Optional[Dict[str, int]] = None + + +# ============================================================================= +# AsyncOpenAI-Compatible Client Classes +# ============================================================================= + + +class _CompletionsNamespace: + """ + Mimics openai.resources.chat.completions.AsyncCompletions. + + Provides the create() method that verifiers calls. + """ + + def __init__(self, parent: "AtroposManagedClient"): + self.parent = parent + + async def create( + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + n: int = 1, + max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict]] = None, + stop: Optional[List[str]] = None, + **kwargs, + ) -> EnhancedChatCompletion: + """ + Create chat completion with token tracking. 
+ + Returns ChatCompletion with additional attributes: + - prompt_token_ids: list[int] + - choices[i].token_ids: list[int] + - choices[i].logprobs.content: list with logprob info + + Args: + messages: List of message dicts with 'role' and 'content' + model: Model name (defaults to client's model) + n: Number of completions (should be 1 for multi-turn) + max_tokens: Max tokens in completion (legacy param) + max_completion_tokens: Max tokens in completion (new param) + temperature: Sampling temperature + top_p: Nucleus sampling parameter + tools: Tool definitions for function calling + stop: Stop sequences + **kwargs: Additional parameters passed to ManagedServer + """ + # Use max_completion_tokens if provided, else max_tokens + effective_max_tokens = max_completion_tokens or max_tokens + + # Build kwargs for ManagedServer + completion_kwargs = { + "messages": messages, + "model": model or self.parent.model, + "n": n, + "temperature": temperature, + "top_p": top_p, + } + + if effective_max_tokens is not None: + completion_kwargs["max_tokens"] = effective_max_tokens + + if tools is not None: + completion_kwargs["tools"] = tools + + if stop is not None: + completion_kwargs["stop"] = stop + + # Add any extra kwargs (like logprobs settings) + for key, value in kwargs.items(): + if value is not None: + completion_kwargs[key] = value + + # Call ManagedServer for inference + completion = await self.parent.managed_server.chat_completion( + **completion_kwargs + ) + + # Get token state from managed server + state = self.parent.managed_server.get_state() + nodes: List[SequenceNode] = state["nodes"] + + # Inject token data into response + return self._enhance_completion(completion, nodes) + + def _enhance_completion( + self, completion: Any, nodes: List[SequenceNode] + ) -> EnhancedChatCompletion: + """ + Convert ManagedServer output to verifiers-compatible format. 
+ + Extracts token data from SequenceNodes and injects it into the + ChatCompletion response in the format verifiers expects. + """ + enhanced_choices = [] + prompt_token_ids: List[int] = [] + + for i, (choice, node) in enumerate(zip(completion.choices, nodes)): + # Find prompt/completion boundary from masked_tokens + # -100 indicates prompt tokens, actual token IDs indicate completion + prompt_len = sum(1 for m in node.masked_tokens if m == -100) + + # Extract prompt and completion portions + if i == 0: + prompt_token_ids = node.tokens[:prompt_len] + + completion_ids = node.tokens[prompt_len:] + completion_logprobs = node.logprobs[prompt_len:] + + # Build logprobs structure verifiers expects + logprobs_content = [] + tokenizer = self.parent.managed_server.tokenizer + + for token_id, logprob in zip(completion_ids, completion_logprobs): + # Decode token to string if tokenizer available + token_str = "" + if tokenizer is not None: + try: + token_str = tokenizer.decode([token_id]) + except Exception: + token_str = f"<token_{token_id}>" + + logprobs_content.append( + LogprobContent( + logprob=logprob, + token=token_str, + token_id=token_id, + ) + ) + + # Create enhanced choice with token data + enhanced_choice = EnhancedChoice( + index=choice.index, + message=choice.message, + finish_reason=choice.finish_reason or "stop", + token_ids=completion_ids, + logprobs=ChoiceLogprobs(content=logprobs_content), + ) + enhanced_choices.append(enhanced_choice) + + return EnhancedChatCompletion( + id=completion.id, + created=completion.created, + model=completion.model, + object=completion.object, + choices=enhanced_choices, + prompt_token_ids=prompt_token_ids, + usage=completion.usage.model_dump() if completion.usage else None, + ) + + +class _ChatNamespace: + """Mimics openai.resources.chat.AsyncChat.""" + + def __init__(self, parent: "AtroposManagedClient"): + self.completions = _CompletionsNamespace(parent) + + +class AtroposManagedClient: + """ + AsyncOpenAI-compatible client backed by 
ManagedServer. + + This client provides the same interface as AsyncOpenAI but uses Atropos's + ManagedServer for inference, enabling automatic token tracking for + multi-turn RL training with the Verifiers library. + + The key feature is that responses include token data attributes that + verifiers' parse_response_tokens() expects: + - response.prompt_token_ids + - response.choices[i].token_ids + - response.choices[i].logprobs.content[j].logprob + + Usage: + async with server_manager.managed_server(tokenizer=tokenizer) as managed: + client = AtroposManagedClient( + managed_server=managed, + model="Qwen/Qwen2.5-1.5B-Instruct" + ) + + # Pass to verifiers env.rollout() + state = await vf_env.rollout( + input=rollout_input, + client=client, + model="Qwen/Qwen2.5-1.5B-Instruct", + ) + + # Token data is now in state["trajectory"][i]["tokens"] + """ + + def __init__( + self, + managed_server: ManagedServer, + model: str, + base_url: Optional[str] = None, + ): + """ + Initialize the managed client. + + Args: + managed_server: ManagedServer instance for inference and token tracking + model: Model name to use for completions + base_url: Optional base URL (for API compatibility, not used) + """ + self.managed_server = managed_server + self.model = model + self.base_url = base_url or "http://managed-server" + + # Mimic AsyncOpenAI namespace structure + self.chat = _ChatNamespace(self) + + def reset(self): + """Reset token tracking state between rollouts.""" + self.managed_server.reset() + + async def close(self): + """Compatibility method - no-op since ManagedServer handles cleanup.""" + pass + + def copy(self, **_kwargs) -> "AtroposManagedClient": + """ + Create a copy of this client (for API compatibility). + + Verifiers may call client.copy() for certain operations. + Returns self since we want to maintain the same ManagedServer state. 
+ """ + return self diff --git a/environments/configs/verifiers.yaml b/environments/configs/verifiers.yaml new file mode 100644 index 00000000..91ef7ec2 --- /dev/null +++ b/environments/configs/verifiers.yaml @@ -0,0 +1,31 @@ +# Verifiers environment configuration +# Usage: python environments/verifiers_server.py serve --config environments/configs/verifiers.yaml +# +# For SFT data generation with external API: +# python environments/verifiers_server.py process \ +# --env.vf_env_name primeintellect/gsm8k \ +# --env.data_path_to_save_groups output.jsonl \ +# --openai.base_url https://api.openai.com/v1 \ +# --openai.api_key $OPENAI_API_KEY \ +# --openai.model_name gpt-4o + +env: + vf_env_name: "primeintellect/gsm8k" # Prime Env Hub environment + env_args: {} + group_size: 8 + max_token_length: 2048 + tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct" + rollout_server_url: "http://localhost:8000" + use_wandb: true + wandb_name: "verifiers" + total_steps: 1000 + batch_size: 4 + steps_per_eval: 100 + +openai: + - model_name: "Qwen/Qwen2.5-1.5B-Instruct" + base_url: "http://localhost:9001/v1" + api_key: "x" + +slurm: false +testing: false diff --git a/environments/eval_environments/verifiers_eval.py b/environments/eval_environments/verifiers_eval.py index 98b95ffa..c5fb1063 100644 --- a/environments/eval_environments/verifiers_eval.py +++ b/environments/eval_environments/verifiers_eval.py @@ -18,15 +18,15 @@ Usage: """ import asyncio +import inspect import os import time -from typing import Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import verifiers as vf -import wandb from pydantic import Field -from tqdm.asyncio import tqdm_asyncio +import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, @@ -34,6 +34,36 @@ from atroposlib.envs.base import ( ) +# Patch math_verify timeout to work in async context +# The signal-based timeout doesn't work in non-main threads (asyncio event loop) +def 
_no_signal_timeout(timeout_seconds: int): + """Replacement timeout decorator that doesn't use signals.""" + + def decorator(func): + def wrapper(*args, **kwargs): + # Just call the function without timeout + # This is safe because we're in an async context with our own timeouts + # timeout_seconds is intentionally unused - we're replacing the timeout logic + return func(*args, **kwargs) + + return wrapper + + return decorator + + +try: + import math_verify.grader + import math_verify.parser + import math_verify.utils + + # Patch all modules that use the timeout decorator + math_verify.utils.timeout = _no_signal_timeout + math_verify.parser.timeout = _no_signal_timeout + math_verify.grader.timeout = _no_signal_timeout +except ImportError: + pass # math_verify not installed + + class VerifiersEvaluationConfig(BaseEnvConfig): """Configuration for Verifiers evaluation environment.""" @@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv): self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args) self.rubric = self.vf_env.rubric - # Extract rubric components + # Extract rubric components from RubricGroup + # RubricGroup.funcs is empty - need to collect from individual rubrics self.parser = self.rubric.parser - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights - self.reward_scales = [ - weight / sum(self.reward_weights) for weight in self.reward_weights - ] + self.reward_funcs: List[Callable] = [] + self.reward_weights: List[float] = [] + self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func + + if hasattr(self.rubric, "rubrics"): + # RubricGroup: collect from all individual rubrics + for rubric in self.rubric.rubrics: + class_objects = getattr(rubric, "class_objects", {}) + for func, weight in zip(rubric.funcs, rubric.weights): + self.reward_funcs.append(func) + self.reward_weights.append(weight) + self.rubric_class_objects.append(class_objects) + else: + # Single Rubric + self.reward_funcs = 
self.rubric.funcs + self.reward_weights = self.rubric.weights + class_objects = getattr(self.rubric, "class_objects", {}) + self.rubric_class_objects = [class_objects] * len(self.rubric.funcs) + + total_weight = sum(self.reward_weights) if self.reward_weights else 1.0 + self.reward_scales = [weight / total_weight for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt # Tracking @@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv): # Parse answer answer_parsed = self.parser.parse_answer(completion=response_text) - # Score using reward funcs + # Score using reward funcs (async functions need await) + # Use signature introspection to pass only required params (like verifiers does) rewards = [] - for func in self.reward_funcs: - reward = func( - parser=self.parser, - completion=completion_messages, - answer=answer, - ) + for i, func in enumerate(self.reward_funcs): + try: + # Build merged dict of all possible parameters + class_objects = self.rubric_class_objects[i] + merged = { + "completion": completion_messages, + "answer": answer, + "prompt": question, + } + merged.update(class_objects) # Adds parser, etc. 
+ + # Filter to only params the function accepts + sig = inspect.signature(func) + if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + # Function accepts **kwargs, pass everything + kwargs = merged + else: + # Only pass params in signature + kwargs = {k: v for k, v in merged.items() if k in sig.parameters} + + result = func(**kwargs) + # Reward functions may be async coroutines + if asyncio.iscoroutine(result): + reward = await result + else: + reward = result + reward = float(reward) + except Exception as e: + if self.config.full_debug: + print(f" Reward func {func.__name__} error: {e}") + reward = 0.0 rewards.append(reward) weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)] @@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv): start_time = time.time() - # Create evaluation tasks - tasks = [self.rollout_and_score(item) for item in self.eval_items] + # Run sequentially to avoid signal/threading issues with math_verify parser + # The parser uses signals for timeouts which only work in main thread + from tqdm import tqdm - # Run with progress bar - results = await tqdm_asyncio.gather(*tasks, desc="Evaluating") + results = [] + for item in tqdm(self.eval_items, desc="Evaluating"): + result = await self.rollout_and_score(item) + results.append(result) # Filter out failed results valid_results = [r for r in results if r is not None] diff --git a/environments/verifiers_server.py b/environments/verifiers_server.py index 434e1368..42d77c1c 100644 --- a/environments/verifiers_server.py +++ b/environments/verifiers_server.py @@ -39,10 +39,12 @@ To install a Verifiers/Prime environment: Docs: https://docs.primeintellect.ai/tutorials-environments/install """ +import asyncio import time from typing import Any, Dict, List, Optional, Tuple import verifiers as vf +from openai import AsyncOpenAI from pydantic import Field from tqdm.asyncio import tqdm_asyncio @@ -82,11 +84,21 @@ class VerifiersEnv(BaseEnv): self.rubric = 
self.vf_env.rubric self.parser = self.rubric.parser - self.reward_funcs = self.rubric.funcs - self.reward_weights = self.rubric.weights - self.reward_scales = [ - weight / sum(self.reward_weights) for weight in self.reward_weights - ] + + # Handle both single Rubric and RubricGroup (composite) + # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics + if hasattr(self.rubric, "rubrics"): + self.reward_funcs = [] + self.reward_weights = [] + for rubric in self.rubric.rubrics: + self.reward_funcs.extend(rubric.funcs) + self.reward_weights.extend(rubric.weights) + else: + self.reward_funcs = self.rubric.funcs + self.reward_weights = self.rubric.weights + + total = sum(self.reward_weights) if self.reward_weights else 1.0 + self.reward_scales = [weight / total for weight in self.reward_weights] self.system_prompt = self.vf_env.system_prompt @classmethod @@ -135,9 +147,15 @@ class VerifiersEnv(BaseEnv): async def setup(self): train_data = self.vf_env.get_dataset() - self.train = train_data.select_columns(["question", "answer"]).to_list() + # Only load columns we need to avoid memory bloat + columns_to_keep = ["question", "answer", "info"] + available_columns = [c for c in columns_to_keep if c in train_data.column_names] + self.train = train_data.select_columns(available_columns).to_list() test_data = self.vf_env.get_eval_dataset() - self.test = test_data.select_columns(["question", "answer"]).to_list() + available_test_columns = [ + c for c in columns_to_keep if c in test_data.column_names + ] + self.test = test_data.select_columns(available_test_columns).to_list() self.iter = 0 def save_checkpoint(self, step, data=None): @@ -254,70 +272,116 @@ class VerifiersEnv(BaseEnv): async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]: """ - Collect trajectories - automatically switches between: - - ManagedServer (for RL training with local server - requires logprobs) - - tokenize_for_trainer (for SFT datagen with any API - 
no logprobs needed) + Collect trajectories - switches between: + - SFT data generation (process mode): Any API, no logprobs needed + - RL training (serve mode): Local server with logprobs """ - question = item["question"] - answer = item["answer"] - - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": question}, - ] - - # Check if we're in process mode (SFT data generation) is_process_mode = getattr(self, "process_mode", False) if is_process_mode: - return await self._collect_trajectories_for_sft(messages, answer) + return await self._collect_trajectories_for_sft(item) else: - return await self._collect_trajectories_for_training(messages, answer) + return await self._collect_trajectories_for_rl(item) async def _collect_trajectories_for_sft( - self, messages: List[Dict], answer: str + self, item: Dict[str, Any] ) -> Tuple[ScoredDataGroup, list]: """ SFT data generation mode - works with ANY API (OpenAI, Claude, local). Does NOT require logprobs or local server. - Uses tokenize_for_trainer to tokenize completions with your training - tokenizer, so the resulting data is ready for fine-tuning your target model. + Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer + to tokenize completions with your training tokenizer. 
""" - completions = await self.server.chat_completion( - messages=messages, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, + question = item["question"] + answer = item["answer"] + + # Build initial messages + initial_messages: List[Dict[str, str]] = [] + if self.system_prompt: + initial_messages.append({"role": "system", "content": self.system_prompt}) + initial_messages.append({"role": "user", "content": question}) + + # Create AsyncOpenAI client directly from server config (no ManagedServer needed) + server_config = self.server.servers[0].config + client = AsyncOpenAI( + api_key=server_config.api_key, + base_url=server_config.base_url, + timeout=server_config.timeout, ) + # Sampling args - use max_completion_tokens for newer models like gpt-5 + sampling_args = { + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + } + scored_data = ScoredDataGroup() scored_data["tokens"] = [] scored_data["masks"] = [] scored_data["scores"] = [] scored_data["messages"] = [] - # Note: No inference_logprobs - not needed/available for SFT - for choice in completions.choices: - response = choice.message.content or "" - finish_reason = choice.finish_reason or "" + # Semaphore for scoring (required by rubric.score_rollout) + score_sem = asyncio.Semaphore(1) - # Build full conversation for scoring and tokenization - completion_messages = messages + [ - {"role": "assistant", "content": response} + # Run rollouts in parallel for group_size + async def run_single_rollout(example_id: int): + # Pass through any info from the dataset item (e.g., docker_image for SWE envs) + item_info = item.get("info", {}) + rollout_input = { + "prompt": initial_messages, + "answer": answer, + "example_id": example_id, + "task": self.config.vf_env_name, + "info": item_info, + } + state = await self.vf_env.rollout( + input=rollout_input, + client=client, + model=server_config.model_name, + sampling_args=sampling_args, + ) + # Score the 
rollout using verifiers rubric (computes reward from test output) + # This is needed because vf_env.rollout() doesn't call score_rollout + await self.rubric.score_rollout(state, score_sem=score_sem) + return state + + # Run group_size rollouts in parallel + rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)] + states = await asyncio.gather(*rollout_tasks) + + for state in states: + # Extract completion messages from state + completion_messages = list(state.get("prompt", [])) + list( + state.get("completion", []) + ) + # Ensure all message contents are strings (not None) + # This can happen with tool call messages that have content: null + completion_messages = [ + {**msg, "content": msg.get("content") or ""} + for msg in completion_messages ] - # Score using verifiers reward funcs - score = self._compute_score(completion_messages, answer) + # Get reward from verifiers scoring (set by rubric.score_rollout above) + score = state.get("reward", 0.0) + + # Determine finish reason from last trajectory step + trajectory = state.get("trajectory", []) + if trajectory: + finish_reason = trajectory[-1]["response"].choices[0].finish_reason + else: + finish_reason = "stop" # Use tokenize_for_trainer for tokenization - # This uses YOUR training tokenizer (e.g., Qwen, Llama), not the API's tokenizer - # So GPT-4o responses get tokenized for your target model + # train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked + # for multi-turn environments, not just the last message tokenized = tokenize_for_trainer( tokenizer=self.tokenizer, chat=completion_messages, include_messages=True, finish_reason=finish_reason, + train_on_all_assistant_turns=True, ) scored_data["tokens"].append(tokenized["tokens"]) @@ -331,49 +395,73 @@ class VerifiersEnv(BaseEnv): return scored_data, [] - async def _collect_trajectories_for_training( - self, messages: List[Dict], answer: str + async def _collect_trajectories_for_rl( + self, item: Dict[str, Any] ) -> 
Tuple[ScoredDataGroup, list]: """ - RL training mode - requires local inference server. - Uses ManagedServer for proper token/logprob alignment. - - The inference_logprobs are required for policy gradient methods like - GRPO, PPO, REINFORCE, etc. + RL training mode - requires local inference server for logprobs. + Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn. """ - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - completions = await managed.chat_completion( - messages=messages, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, - ) - state = managed.get_state() - nodes = state["nodes"] + from atroposlib.envs.server_handling.atropos_managed_client import ( + AtroposManagedClient, + ) + + question = item["question"] + answer = item["answer"] + item_info = item.get("info", {}) + + initial_messages: List[Dict[str, str]] = [] + if self.system_prompt: + initial_messages.append({"role": "system", "content": self.system_prompt}) + initial_messages.append({"role": "user", "content": question}) + + sampling_args = { + "temperature": 1.0, + "max_completion_tokens": self.config.max_token_length, + } scored_data = ScoredDataGroup() scored_data["tokens"] = [] scored_data["masks"] = [] scored_data["scores"] = [] - scored_data["inference_logprobs"] = [] # Required for RL training! 
+ scored_data["inference_logprobs"] = [] - for i, choice in enumerate(completions.choices): - response = choice.message.content or "" + # Semaphore for scoring (required by rubric.score_rollout) + score_sem = asyncio.Semaphore(1) - # Build full conversation for scoring - completion_messages = messages + [ - {"role": "assistant", "content": response} - ] + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + client = AtroposManagedClient( + managed_server=managed, + model=self.server_configs[0].model_name, + ) - # Score using verifiers reward funcs - score = self._compute_score(completion_messages, answer) + # Run group_size rollouts sequentially (ManagedServer state must be reset between) + for i in range(self.config.group_size): + client.reset() - # Use ManagedServer's properly aligned tokens/masks/logprobs - node = nodes[i] - scored_data["tokens"].append(node.tokens) - scored_data["masks"].append(node.masked_tokens) - scored_data["inference_logprobs"].append(node.logprobs) - scored_data["scores"].append(score) + rollout_input = { + "prompt": initial_messages, + "answer": answer, + "example_id": i, + "task": self.config.vf_env_name, + "info": item_info, + } + + state = await self.vf_env.rollout( + input=rollout_input, + client=client, + model=self.server_configs[0].model_name, + sampling_args=sampling_args, + ) + + # Score the rollout (computes reward from test output) + await self.rubric.score_rollout(state, score_sem=score_sem) + + tokens, masks, logprobs, score = self._extract_from_state(state) + scored_data["tokens"].append(tokens) + scored_data["masks"].append(masks) + scored_data["inference_logprobs"].append(logprobs) + scored_data["scores"].append(score) # Track scores for wandb logging for score in scored_data["scores"]: @@ -381,6 +469,42 @@ class VerifiersEnv(BaseEnv): return scored_data, [] + def _extract_from_state( + self, state: Any + ) -> Tuple[List[int], List[int], List[float], float]: + """ + Extract 
tokens/masks/logprobs/score from a single rollout state. + + Handles the mask convention conversion: + - Verifiers: prompt_mask=0, completion_mask=1 + - Atropos: masked_tokens=-100 (prompt), token_id (completion) + """ + all_tokens: List[int] = [] + all_masks: List[int] = [] + all_logprobs: List[float] = [] + + trajectory = state.get("trajectory", []) + + for step in trajectory: + tokens = step["tokens"] + + prompt_ids = tokens["prompt_ids"] + completion_ids = tokens["completion_ids"] + completion_logprobs = tokens["completion_logprobs"] + + all_tokens.extend(prompt_ids) + all_tokens.extend(completion_ids) + + all_masks.extend([-100] * len(prompt_ids)) + all_masks.extend(completion_ids) + + all_logprobs.extend([1.0] * len(prompt_ids)) + all_logprobs.extend(completion_logprobs) + + reward = state["reward"] + + return all_tokens, all_masks, all_logprobs, reward + if __name__ == "__main__": VerifiersEnv.cli()