Mirror of https://github.com/NousResearch/atropos.git
Synced 2026-04-28 17:29:30 +00:00
parallelize verifiers_server: use generate() for SFT, parallel ManagedServer contexts for RL
This commit is contained in:
parent 24b4488c60
commit dceb1d8fd8
1 changed file with 93 additions and 352 deletions
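
The change in a nutshell: the old RL path opened a single ManagedServer context and ran the group's rollouts strictly one after another, resetting client state between iterations; the new path gives each rollout its own context so the whole group runs concurrently. A minimal sketch of that pattern, with `open_context` and `do_rollout` as hypothetical stand-ins for `self.server.managed_server(...)` and `vf_env.rollout(...)` plus scoring:

import asyncio

async def rollout_group_sequential(open_context, do_rollout, group_size):
    # Old RL path: one shared context; rollouts run one after another,
    # with client state reset between iterations.
    results = []
    async with open_context() as ctx:
        for i in range(group_size):
            results.append(await do_rollout(ctx, i))
    return results

async def rollout_group_parallel(open_context, do_rollout, group_size):
    # New RL path: each rollout owns its context, so no state is shared
    # and the whole group can run concurrently under asyncio.gather.
    async def one(i):
        async with open_context() as ctx:
            return await do_rollout(ctx, i)

    return await asyncio.gather(*(one(i) for i in range(group_size)))

On the SFT side the commit drops the hand-rolled rollout loop entirely and delegates concurrency to verifiers' own vf_env.generate(), which accepts max_concurrent / max_concurrent_scoring limits (see the diff below).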
@@ -41,13 +41,11 @@ Docs: https://docs.primeintellect.ai/tutorials-environments/install
 import asyncio
 import logging
-import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import verifiers as vf
 from openai import AsyncOpenAI
 from pydantic import Field
-from tqdm.asyncio import tqdm_asyncio

 from atroposlib.envs.base import (
     APIServerConfig,
@@ -60,59 +58,7 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 logger = logging.getLogger(__name__)


-# =============================================================================
-# Verifiers API Compatibility Layer
-# =============================================================================
-
-
-def _get_rubric_reward_funcs(rubric: vf.Rubric) -> List[Callable]:
-    """
-    Get reward functions from a Rubric with API version compatibility.
-
-    Handles different verifiers API versions:
-    - v0.1.9+: rubric._get_reward_funcs() (private)
-    - v0.1.5-0.1.8: rubric.get_reward_funcs() (public)
-    - fallback: rubric.funcs (direct access)
-    """
-    if hasattr(rubric, "_get_reward_funcs"):
-        return rubric._get_reward_funcs()
-    elif hasattr(rubric, "get_reward_funcs"):
-        return rubric.get_reward_funcs()
-    elif hasattr(rubric, "funcs"):
-        return rubric.funcs
-    else:
-        raise AttributeError(
-            f"Cannot find reward functions on rubric. "
-            f"Available attrs: {dir(rubric)}"
-        )
-
-
-def _get_rubric_reward_weights(rubric: vf.Rubric) -> List[float]:
-    """
-    Get reward weights from a Rubric with API version compatibility.
-
-    Handles different verifiers API versions:
-    - v0.1.9+: rubric._get_reward_weights() (private)
-    - v0.1.5-0.1.8: rubric.get_reward_weights() (public)
-    - fallback: rubric.weights (direct access)
-    """
-    if hasattr(rubric, "_get_reward_weights"):
-        return rubric._get_reward_weights()
-    elif hasattr(rubric, "get_reward_weights"):
-        return rubric.get_reward_weights()
-    elif hasattr(rubric, "weights"):
-        return rubric.weights
-    else:
-        raise AttributeError(
-            f"Cannot find reward weights on rubric. " f"Available attrs: {dir(rubric)}"
-        )
-
-
 class VfEnvConfig(BaseEnvConfig):
     """
     Configuration for the Verifiers environments.
     """

     vf_env_name: str = ""
     env_args: Dict[str, Any] = Field(default_factory=dict)
@@ -130,38 +76,12 @@ class VerifiersEnv(BaseEnv):
     ):
         super().__init__(config, server_configs, slurm, testing)
         self.percent_correct_buffer: List[float] = []
         self.eval_metrics: List[Tuple[str, float]] = []

-        # Load verifiers environment
         logger.info("Loading verifiers environment: %s", config.vf_env_name)
         self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
         self.rubric = self.vf_env.rubric
-
-        self.parser = self.rubric.parser
-
-        # Handle both single Rubric and RubricGroup (composite)
-        # RubricGroup has empty funcs/weights at top level - must extract from individual rubrics
-        if hasattr(self.rubric, "rubrics"):
-            # RubricGroup: collect from all individual rubrics
-            self.reward_funcs: List[Callable] = []
-            self.reward_weights: List[float] = []
-            for rubric in self.rubric.rubrics:  # type: ignore[attr-defined]
-                self.reward_funcs.extend(_get_rubric_reward_funcs(rubric))
-                self.reward_weights.extend(_get_rubric_reward_weights(rubric))
-        else:
-            # Single Rubric: use compatibility layer
-            self.reward_funcs = _get_rubric_reward_funcs(self.rubric)
-            self.reward_weights = _get_rubric_reward_weights(self.rubric)
-
-        total = sum(self.reward_weights) if self.reward_weights else 1.0
-        self.reward_scales = [weight / total for weight in self.reward_weights]
         self.system_prompt = self.vf_env.system_prompt

-        logger.info(
-            "Loaded environment with %d reward functions, system_prompt=%s",
-            len(self.reward_funcs),
-            bool(self.system_prompt),
-        )
-        logger.info("Reward functions: %s", self.rubric._get_reward_func_names())

     @classmethod
     def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
@@ -176,9 +96,6 @@ class VerifiersEnv(BaseEnv):
             max_token_length=2048,
             wandb_name="verifiers",
         )
-        # Default config for local inference server (vLLM, SGLang, TRL)
-        # For SFT data generation with OpenAI, override via CLI:
-        #   --openai.base_url https://api.openai.com/v1 --openai.model_name gpt-4o
         server_configs = [
             APIServerConfig(
                 model_name="gpt-4.1-nano",
@@ -193,31 +110,19 @@ class VerifiersEnv(BaseEnv):
         if wandb_metrics is None:
             wandb_metrics = {}

         # Calculate percent_correct from buffer
         if self.percent_correct_buffer:
             wandb_metrics["train/percent_correct"] = sum(
                 self.percent_correct_buffer
             ) / len(self.percent_correct_buffer)

-        self.percent_correct_buffer = list()
-
         for item in self.eval_metrics:
             wandb_metrics[item[0]] = item[1]
         self.eval_metrics = list()
+        self.percent_correct_buffer = []

         await super().wandb_log(wandb_metrics)

     async def setup(self):
         train_data = self.vf_env.get_dataset()
         # Only load columns we need to avoid memory bloat
         columns_to_keep = ["question", "answer", "info"]
         available_columns = [c for c in columns_to_keep if c in train_data.column_names]
         self.train = train_data.select_columns(available_columns).to_list()
-        test_data = self.vf_env.get_eval_dataset()
-        available_test_columns = [
-            c for c in columns_to_keep if c in test_data.column_names
-        ]
-        self.test = test_data.select_columns(available_test_columns).to_list()
         self.iter = 0

     def save_checkpoint(self, step, data=None):
@@ -226,145 +131,32 @@ class VerifiersEnv(BaseEnv):
         data["iter"] = self.iter
         super().save_checkpoint(step, data)

-    def _compute_score(self, completion_messages: List[Dict], answer: str) -> float:
-        """Compute score using verifiers reward functions."""
-        rewards = []
-        for func in self.reward_funcs:
-            reward = func(
-                parser=self.parser,
-                completion=completion_messages,
-                answer=answer,
-            )
-            rewards.append(reward)
-        weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
-        return sum(weighted_rewards)
-
-    async def rollout_and_score_eval(
-        self, question: str, answer: str, **kwargs
-    ) -> dict:
-        """
-        Rollout and score for evaluation.
-        Uses ManagedServer in serve mode, direct API calls in process mode.
-        """
-        system_prompt = kwargs.get("system_prompt")
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": question},
-        ]
-
-        is_process_mode = getattr(self, "process_mode", False)
-
-        if is_process_mode:
-            # Process mode: use direct API call (works with any API)
-            completion = await self.server.chat_completion(
-                messages=messages,
-                n=1,
-                max_tokens=self.config.max_token_length,
-                temperature=0.0,
-            )
-        else:
-            # Serve mode: use ManagedServer for token tracking
-            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
-                completion = await managed.chat_completion(
-                    messages=messages,
-                    n=1,
-                    max_tokens=self.config.max_token_length,
-                    temperature=0.0,
-                )
-
-        response_content = completion.choices[0].message.content or ""
-        messages.append({"role": "assistant", "content": response_content})
-
-        answer_parsed = self.parser.parse_answer(completion=response_content)
-
-        score = self._compute_score(messages, answer)
-
-        sample = {
-            "messages": messages,
-            "question": question,
-            "gold_answer": answer,
-            "model_parsed": str(answer_parsed) if answer_parsed else None,
-            "score": score,
-            "correct": bool(score),
-            "finish_reason": completion.choices[0].finish_reason,
-        }
-
-        return {"score": score, "sample": sample}
-
-    async def evaluate(self, *args, **kwargs):
-        start_time = time.time()
-
-        eval_tasks = []
-        for item in self.test:
-            eval_tasks.append(
-                self.rollout_and_score_eval(
-                    item["question"], item["answer"], system_prompt=self.system_prompt
-                )
-            )
-        results = await tqdm_asyncio.gather(*eval_tasks)
-
-        scores = [result["score"] for result in results]
-        samples = [result["sample"] for result in results]
-
-        avg_total_score = sum(scores) / len(scores)
-
-        end_time = time.time()
-
-        self.eval_metrics.append(("eval/avg_total_score", avg_total_score))
-
-        eval_metrics = {"eval/avg_total_score": avg_total_score}
-
-        await self.evaluate_log(
-            metrics=eval_metrics,
-            samples=samples,
-            start_time=start_time,
-            end_time=end_time,
-            generation_parameters={
-                "temperature": 0.0,
-                "max_tokens": self.config.max_token_length,
-            },
-        )
-
-        return eval_metrics
-
     async def get_next_item(self):
         next_item = self.train[self.iter % len(self.train)]
         self.iter += 1
         return next_item

+    async def evaluate(self) -> Dict[str, float]:
+        """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
+        return {}
+
+    def _build_initial_messages(self, question: str) -> List[Dict[str, str]]:
+        messages: List[Dict[str, str]] = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": question})
+        return messages
+
     async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
         """
         Collect trajectories - switches between:
         - SFT data generation (process mode): Any API, no logprobs needed
         - RL training (serve mode): Local server with logprobs
         """
         is_process_mode = getattr(self, "process_mode", False)

         if is_process_mode:
-            return await self._collect_trajectories_for_sft(item)
-        else:
-            return await self._collect_trajectories_for_rl(item)
+            return await self._collect_for_sft(item)
+        return await self._collect_for_rl(item)

-    async def _collect_trajectories_for_sft(
+    async def _collect_for_sft(
         self, item: Dict[str, Any]
     ) -> Tuple[ScoredDataGroup, list]:
-        """
-        SFT data generation mode - works with ANY API (OpenAI, Claude, local).
-        Does NOT require logprobs or local server.
-
-        Uses verifiers rollout() for multi-turn environments and tokenize_for_trainer
-        to tokenize completions with your training tokenizer.
-        """
-        question = item["question"]
-        answer = item["answer"]
-
-        # Build initial messages
-        initial_messages: List[Dict[str, str]] = []
-        if self.system_prompt:
-            initial_messages.append({"role": "system", "content": self.system_prompt})
-        initial_messages.append({"role": "user", "content": question})
-
-        # Create AsyncOpenAI client directly from server config (no ManagedServer needed)
+        """SFT mode: uses vf_env.generate() for parallel generation + scoring."""
         server_config = self.server.servers[0].config
         client = AsyncOpenAI(
             api_key=server_config.api_key,
@@ -372,11 +164,31 @@ class VerifiersEnv(BaseEnv):
             timeout=server_config.timeout,
         )

-        # Sampling args - use max_completion_tokens for newer models like gpt-5
-        sampling_args = {
-            "temperature": 1.0,
-            "max_completion_tokens": self.config.max_token_length,
-        }
+        initial_messages = self._build_initial_messages(item["question"])
+        inputs = [
+            {
+                "prompt": initial_messages,
+                "answer": item["answer"],
+                "example_id": i,
+                "task": self.config.vf_env_name,
+                "info": item.get("info", {}),
+            }
+            for i in range(self.config.group_size)
+        ]
+
+        results = await self.vf_env.generate(
+            inputs=inputs,
+            client=client,
+            model=server_config.model_name,
+            sampling_args={
+                "temperature": 1.0,
+                "max_completion_tokens": self.config.max_token_length,
+            },
+            max_concurrent=self.config.group_size,
+            max_concurrent_scoring=self.config.group_size,
+            save_results=False,
+            independent_scoring=True,
+        )

         scored_data = ScoredDataGroup()
         scored_data["tokens"] = []
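Worth noting about the new SFT path in the hunk above: one dataset item is fanned out into group_size identical inputs, and vf_env.generate() then handles both concurrent generation and concurrent scoring (independent_scoring=True scores each rollout on its own). A hedged usage sketch of the input shape — the question/answer values and the "my-vf-env" task name are made up, and the generate() call is left commented because its arguments are exactly those shown in the hunk above:

item = {"question": "What is 2 + 2?", "answer": "4"}
group_size = 4

inputs = [
    {
        "prompt": [{"role": "user", "content": item["question"]}],
        "answer": item["answer"],
        "example_id": i,  # distinguishes group members that share a prompt
        "task": "my-vf-env",  # hypothetical vf_env_name
        "info": {},
    }
    for i in range(group_size)
]
# results = await vf_env.generate(inputs=inputs, client=client, model=model, ...)
# results["state"][i]["reward"] then carries the score for rollout i.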
@@ -384,63 +196,22 @@ class VerifiersEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["messages"] = []

-        # Semaphore for scoring (required by rubric.score_rollout)
-        score_sem = asyncio.Semaphore(1)
-
-        # Run rollouts in parallel for group_size
-        async def run_single_rollout(example_id: int):
-            # Pass through any info from the dataset item (e.g., docker_image for SWE envs)
-            item_info = item.get("info", {})
-            rollout_input = {
-                "prompt": initial_messages,
-                "answer": answer,
-                "example_id": example_id,
-                "task": self.config.vf_env_name,
-                "info": item_info,
-            }
-            state = await self.vf_env.rollout(
-                input=rollout_input,
-                client=client,
-                model=server_config.model_name,
-                sampling_args=sampling_args,
-            )
-            # Score the rollout using verifiers rubric (computes reward from test output)
-            # This is needed because vf_env.rollout() doesn't call score_rollout
-            await self.rubric.score_rollout(state, score_sem=score_sem)
-            return state
-
-        # Run group_size rollouts in parallel
-        rollout_tasks = [run_single_rollout(i) for i in range(self.config.group_size)]
-        states = await asyncio.gather(*rollout_tasks)
-
-        for state in states:
-            # Extract completion messages from state
-            completion_messages = list(state.get("prompt", [])) + list(
-                state.get("completion", [])
-            )
-            # Ensure all message contents are strings (not None)
-            # This can happen with tool call messages that have content: null
-            completion_messages = [
-                {**msg, "content": msg.get("content") or ""}
-                for msg in completion_messages
-            ]
+        for state in results["state"]:
+            messages = list(state.get("prompt", [])) + list(state.get("completion", []))
+            messages = [
+                {**msg, "content": msg.get("content") or ""} for msg in messages
+            ]

-            # Get reward from verifiers scoring (set by rubric.score_rollout above)
-            score = state.get("reward", 0.0)
-
-            # Determine finish reason from last trajectory step
             trajectory = state.get("trajectory", [])
-            if trajectory:
-                finish_reason = trajectory[-1]["response"].choices[0].finish_reason
-            else:
-                finish_reason = "stop"
+            finish_reason = (
+                trajectory[-1]["response"].choices[0].finish_reason
+                if trajectory
+                else "stop"
+            )

-            # Use tokenize_for_trainer for tokenization
-            # train_on_all_assistant_turns=True ensures ALL assistant turns are unmasked
-            # for multi-turn environments, not just the last message
             tokenized = tokenize_for_trainer(
                 tokenizer=self.tokenizer,
-                chat=completion_messages,
+                chat=messages,
                 include_messages=True,
                 finish_reason=finish_reason,
                 train_on_all_assistant_turns=True,
@@ -448,39 +219,54 @@ class VerifiersEnv(BaseEnv):

             scored_data["tokens"].append(tokenized["tokens"])
             scored_data["masks"].append(tokenized["masks"])
-            scored_data["messages"].append(completion_messages)
-            scored_data["scores"].append(score)
+            scored_data["messages"].append(messages)
+            scored_data["scores"].append(state.get("reward", 0.0))

         # Track scores for wandb logging
         for score in scored_data["scores"]:
             self.percent_correct_buffer.append(max(score, 0))

         return scored_data, []

-    async def _collect_trajectories_for_rl(
+    async def _collect_for_rl(
         self, item: Dict[str, Any]
     ) -> Tuple[ScoredDataGroup, list]:
-        """
-        RL training mode - requires local inference server for logprobs.
-        Uses AtroposManagedClient with vf_env.rollout() for both single-turn and multi-turn.
-        """
+        """RL mode: uses ManagedServer for logprobs tracking."""
         from atroposlib.envs.server_handling.atropos_managed_client import (
             AtroposManagedClient,
         )

-        question = item["question"]
-        answer = item["answer"]
-        item_info = item.get("info", {})
-
-        initial_messages: List[Dict[str, str]] = []
-        if self.system_prompt:
-            initial_messages.append({"role": "system", "content": self.system_prompt})
-        initial_messages.append({"role": "user", "content": question})
-
+        initial_messages = self._build_initial_messages(item["question"])
         sampling_args = {
             "temperature": 1.0,
             "max_completion_tokens": self.config.max_token_length,
         }
+        score_sem = asyncio.Semaphore(self.config.group_size)
+        model = self.server.servers[0].config.model_name
+
+        async def run_rollout(
+            example_id: int,
+        ) -> Tuple[List[int], List[int], List[float], float]:
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                client = AtroposManagedClient(managed_server=managed, model=model)
+
+                state = await self.vf_env.rollout(
+                    input={
+                        "prompt": initial_messages,
+                        "answer": item["answer"],
+                        "example_id": example_id,
+                        "task": self.config.vf_env_name,
+                        "info": item.get("info", {}),
+                    },
+                    client=client,
+                    model=model,
+                    sampling_args=sampling_args,
+                )
+                await self.rubric.score_rollout(state, score_sem=score_sem)
+                return self._extract_from_state(state)
+
+        results = await asyncio.gather(
+            *[run_rollout(i) for i in range(self.config.group_size)]
+        )

         scored_data = ScoredDataGroup()
         scored_data["tokens"] = []
@@ -488,44 +274,12 @@ class VerifiersEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["inference_logprobs"] = []

-        # Semaphore for scoring (required by rubric.score_rollout)
-        score_sem = asyncio.Semaphore(1)
+        for tokens, masks, logprobs, score in results:
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
+            scored_data["scores"].append(score)

-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
-            client = AtroposManagedClient(
-                managed_server=managed,
-                model=self.server_configs[0].model_name,
-            )
-
-            # Run group_size rollouts sequentially (ManagedServer state must be reset between)
-            for i in range(self.config.group_size):
-                client.reset()
-
-                rollout_input = {
-                    "prompt": initial_messages,
-                    "answer": answer,
-                    "example_id": i,
-                    "task": self.config.vf_env_name,
-                    "info": item_info,
-                }
-
-                state = await self.vf_env.rollout(
-                    input=rollout_input,
-                    client=client,
-                    model=self.server_configs[0].model_name,
-                    sampling_args=sampling_args,
-                )
-
-                # Score the rollout (computes reward from test output)
-                await self.rubric.score_rollout(state, score_sem=score_sem)
-
-                tokens, masks, logprobs, score = self._extract_from_state(state)
-                scored_data["tokens"].append(tokens)
-                scored_data["masks"].append(masks)
-                scored_data["inference_logprobs"].append(logprobs)
-                scored_data["scores"].append(score)

         # Track scores for wandb logging
         for score in scored_data["scores"]:
             self.percent_correct_buffer.append(max(score, 0))
@@ -534,38 +288,25 @@ class VerifiersEnv(BaseEnv):
     def _extract_from_state(
         self, state: Any
     ) -> Tuple[List[int], List[int], List[float], float]:
-        """
-        Extract tokens/masks/logprobs/score from a single rollout state.
-
-        Handles the mask convention conversion:
-        - Verifiers: prompt_mask=0, completion_mask=1
-        - Atropos: masked_tokens=-100 (prompt), token_id (completion)
-        """
+        """Extract tokens/masks/logprobs from rollout state (RL mode only)."""
         all_tokens: List[int] = []
         all_masks: List[int] = []
         all_logprobs: List[float] = []

-        trajectory = state.get("trajectory", [])
-
-        for step in trajectory:
+        for step in state.get("trajectory", []):
             tokens = step["tokens"]

             prompt_ids = tokens["prompt_ids"]
             completion_ids = tokens["completion_ids"]
             completion_logprobs = tokens["completion_logprobs"]

             all_tokens.extend(prompt_ids)
             all_tokens.extend(completion_ids)

             all_masks.extend([-100] * len(prompt_ids))
             all_masks.extend(completion_ids)

             all_logprobs.extend([1.0] * len(prompt_ids))
             all_logprobs.extend(completion_logprobs)

-        reward = state["reward"]
-
-        return all_tokens, all_masks, all_logprobs, reward
+        return all_tokens, all_masks, all_logprobs, state["reward"]


 if __name__ == "__main__":
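
For reference, the docstring the last hunk shortens spelled out the mask conversion that _extract_from_state performs: verifiers marks prompt tokens with prompt_mask=0 and completion tokens with completion_mask=1, while the Atropos trainer expects one flat token list with -100 over prompt positions and the token id itself over completion positions. A toy worked example with made-up token ids and logprobs:

# Prompt positions get -100 masks and a placeholder logprob of 1.0;
# completion positions keep the token id and the sampled logprob.
prompt_ids = [101, 102, 103]
completion_ids = [201, 202]
completion_logprobs = [-0.1, -0.3]

tokens = prompt_ids + completion_ids               # [101, 102, 103, 201, 202]
masks = [-100] * len(prompt_ids) + completion_ids  # [-100, -100, -100, 201, 202]
logprobs = [1.0] * len(prompt_ids) + completion_logprobs

assert len(tokens) == len(masks) == len(logprobs) == 5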