add docs :)

commit 5d6d6bb0dc (parent c3a118f50d)
6 changed files with 892 additions and 21 deletions
@@ -373,18 +373,37 @@ class MathEnv(BaseEnv):
         if thinking_len < 1024:
             print("thinking_len is less than 1024, skipping", flush=True)
             return None, []
-        # Use managed server for automatic token/logprob tracking
+
+        # ============================================================================
+        # MANAGED SERVER USAGE - Chat Completion API
+        # ============================================================================
+        # This demonstrates using ManagedServer with the chat_completion() API.
+        # The process is identical to the completion() API (see math_server_zero.py),
+        # but uses OpenAI chat message format instead of raw text prompts.
+        #
+        # ManagedServer automatically:
+        # 1. Applies the tokenizer's chat template to convert messages to text
+        # 2. Tokenizes both prompt and completion
+        # 3. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
+        # 4. Applies proper logprob masking (1.0 for prompt, actual values for completion)
+        # 5. Ensures perfect alignment between tokens and logprobs
+        #
+        # See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
+        # ============================================================================
 
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Call chat_completion through the managed server wrapper
+            # Returns standard OpenAI-compatible ChatCompletion object
             chat_completions = await managed.chat_completion(
                 messages=chat,
-                n=self.config.group_size,
+                n=self.config.group_size,  # Generate multiple completions for GRPO
                 max_tokens=thinking_len,
                 temperature=1.0,
                 top_p=0.95,
             )
+            # Get tracked sequences with aligned tokens and logprobs
             state = managed.get_state()
-            nodes = state["nodes"]
+            nodes = state["nodes"]  # List of SequenceNode objects, one per completion
 
             print("Finished generation", flush=True)
             to_score = list()
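The masking scheme in items 3-5 of the new comment block is the core of what ManagedServer automates. A minimal sketch of what that alignment amounts to, written against hypothetical prompt_ids / completion_ids / completion_logprobs inputs rather than anything actually in atroposlib:

# Illustration only -- not the ManagedServer implementation.
# prompt_ids, completion_ids, completion_logprobs are hypothetical inputs.
def build_training_arrays(prompt_ids, completion_ids, completion_logprobs):
    tokens = prompt_ids + completion_ids
    # -100 at prompt positions so the loss ignores them; real IDs for the completion
    masked_tokens = [-100] * len(prompt_ids) + completion_ids
    # 1.0 placeholders at prompt positions; real logprobs for the completion
    logprobs = [1.0] * len(prompt_ids) + completion_logprobs
    # item 5: all three arrays stay the same length, position for position
    assert len(tokens) == len(masked_tokens) == len(logprobs)
    return tokens, masked_tokens, logprobs

Since genuine logprobs are always <= 0, the 1.0 placeholder presumably doubles as an unambiguous marker for prompt positions.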
@@ -397,14 +416,18 @@ class MathEnv(BaseEnv):
                     {"role": "user", "content": user_prompt},
                     {"role": "assistant", "content": chat_completion.message.content},
                 )
+                # Extract pre-computed data from SequenceNode
+                # node.tokens: Full unmasked tokens [prompt + completion]
+                # node.masked_tokens: [-100, ..., -100, tok1, tok2, ...] for training
+                # node.logprobs: [1.0, ..., 1.0, logp1, logp2, ...] for training
                 to_score.append(
                     (
                         messages,
-                        item[1],
+                        item[1],  # Ground truth answer
                         chat_completion.finish_reason,
-                        node.tokens,
-                        node.masked_tokens,
-                        node.logprobs,
+                        node.tokens,  # Pre-computed by ManagedServer
+                        node.masked_tokens,  # Pre-computed by ManagedServer
+                        node.logprobs,  # Pre-computed by ManagedServer
                     )
                 )
             print("scoring normal", flush=True)
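For orientation, the SequenceNode fields this hunk reads (node.tokens, node.masked_tokens, node.logprobs) can be pictured roughly as the dataclass below. The field names come straight from the diff; the layout itself is an assumption, not the actual atroposlib definition (see MANAGED_SERVER.md for that):

from dataclasses import dataclass
from typing import List

# Rough mental model of a SequenceNode, inferred from the fields the diff
# uses. The real class lives in atroposlib and may carry more state.
@dataclass
class SequenceNodeSketch:
    tokens: List[int]          # full sequence: prompt + completion token IDs
    masked_tokens: List[int]   # prompt positions replaced with -100
    logprobs: List[float]      # prompt positions padded with 1.0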
@@ -776,7 +799,23 @@ class MathEnv(BaseEnv):
             self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
         )
 
-        # Use managed server for both forward and backward completions
+        # ============================================================================
+        # MULTIPLE MANAGED SERVER CONTEXTS - RLAIF Pattern
+        # ============================================================================
+        # This demonstrates using SEPARATE managed_server contexts for independent
+        # completions. Each context tracks its own set of sequences independently.
+        #
+        # Pattern: Create separate async functions that each use their own context,
+        # then gather them in parallel. This is useful for:
+        # - RLAIF (forward/backward preference judgments)
+        # - Multi-step workflows where completions don't extend each other
+        # - Comparing different prompts or conditions
+        #
+        # Note: The tokens/masks/logprobs from these contexts are NOT used directly
+        # in this RLAIF workflow. Instead, we stored them earlier from the original
+        # completions (see lines 461-471 where they're added to backlog_item).
+        # ============================================================================
+
         async def get_fwd_completion():
             async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                 return await managed.chat_completion(
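The "gather them in parallel" step the comment describes is cut off in this hunk. A sketch of how it might look, assuming the same method-body context as the diff (self is the MathEnv instance; chat_fwd, chat_bwd, and max_token_length are hypothetical stand-ins), with get_bwd_completion written by analogy to the diff's get_fwd_completion:

import asyncio  # for gathering the two independent contexts

# Each helper opens its OWN managed_server context, so the forward and
# backward judgments are tokenized and tracked independently.
async def get_fwd_completion():
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        return await managed.chat_completion(
            messages=chat_fwd, n=1, max_tokens=max_token_length
        )

async def get_bwd_completion():
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        return await managed.chat_completion(
            messages=chat_bwd, n=1, max_tokens=max_token_length
        )

# Run both contexts concurrently
fwd_completion, bwd_completion = await asyncio.gather(
    get_fwd_completion(), get_bwd_completion()
)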
@@ -896,7 +935,8 @@ class MathEnv(BaseEnv):
         max_token_length = self.config.max_token_length - len(
             self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
         )
-        # Use managed server for judge completions
+        # Judge completions: Standard managed_server usage
+        # Tokens/masks/logprobs from nodes will be used directly for training
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
             chat_completions = await managed.chat_completion(
                 messages=chat,
@@ -990,7 +1030,10 @@ class MathEnv(BaseEnv):
                 retry_messages, add_generation_prompt=True
             )
         )
-        # Use managed server for retry completions
+        # Retry/self-correction completions: Nested managed_server usage
+        # This demonstrates using managed_server INSIDE another workflow.
+        # Tokens/masks/logprobs from retry_nodes will be stored in backlog
+        # for potential use in the "selfcorrect" trajectory type (see lines 1070-1077)
         async with self.server.managed_server(
             tokenizer=self.tokenizer
         ) as managed:
@@ -1031,7 +1074,9 @@ class MathEnv(BaseEnv):
                 )
             )
             backlog_reasons.append(retry_chat_completion.finish_reason)
-            # Store tokens, masks, and logprobs from managed_server
+            # Store pre-computed tokens/masks/logprobs from ManagedServer
+            # These will be passed through the backlog (line 1110-1116) and
+            # eventually used in collect_trajectories "selfcorrect" case (line 620-636)
             backlog_tokens.append(retry_node.tokens)
             backlog_masks.append(retry_node.masked_tokens)
             backlog_logprobs.append(retry_node.logprobs)
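Read together, the last two hunks describe a producer/consumer round trip: a nested managed_server context generates the retry completion, and its node arrays are stashed in backlog lists for the "selfcorrect" trajectory type to train on later. A condensed sketch, again assuming the diff's method-body context (self, retry_messages, max_token_length); the consumer step is an assumption about shape, since the real flow spans several methods of the file:

# Producer side: nested context generates the retry completion
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    await managed.chat_completion(
        messages=retry_messages, n=1, max_tokens=max_token_length
    )
    retry_node = managed.get_state()["nodes"][0]

# Stash the pre-computed arrays so nothing needs re-tokenizing later
backlog_tokens.append(retry_node.tokens)
backlog_masks.append(retry_node.masked_tokens)
backlog_logprobs.append(retry_node.logprobs)

# Consumer side (the "selfcorrect" case in collect_trajectories): train
# directly on the stored arrays instead of re-running generation
for tokens, masks, logprobs in zip(backlog_tokens, backlog_masks, backlog_logprobs):
    pass  # hand (tokens, masks, logprobs) to the scoring/training step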