add docs :)

Dakota 2025-10-29 11:26:43 -05:00
parent c3a118f50d
commit 5d6d6bb0dc
6 changed files with 892 additions and 21 deletions


@@ -373,18 +373,37 @@ class MathEnv(BaseEnv):
        if thinking_len < 1024:
            print("thinking_len is less than 1024, skipping", flush=True)
            return None, []
        # ============================================================================
        # MANAGED SERVER USAGE - Chat Completion API
        # ============================================================================
        # This demonstrates using ManagedServer with the chat_completion() API.
        # The process is identical to the completion() API (see math_server_zero.py),
        # but uses OpenAI chat message format instead of raw text prompts.
        #
        # ManagedServer automatically:
        # 1. Applies the tokenizer's chat template to convert messages to text
        # 2. Tokenizes both prompt and completion
        # 3. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
        # 4. Applies proper logprob masking (1.0 for prompt, actual values for completion)
        # 5. Ensures perfect alignment between tokens and logprobs
        #
        # See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
        # ============================================================================
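        # A worked micro-example of the alignment guarantee above (hypothetical
        # token IDs, not produced by a real tokenizer): for a 3-token prompt and
        # a 2-token completion, the tracked SequenceNode would hold
        #     node.tokens        == [101, 102, 103, 201, 202]
        #     node.masked_tokens == [-100, -100, -100, 201, 202]
        #     node.logprobs      == [1.0, 1.0, 1.0, -0.31, -1.27]
        # All three lists share one length; the -100 / 1.0 sentinels mark prompt
        # positions so training only scores the completion tokens.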
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Call chat_completion through the managed server wrapper.
            # Returns a standard OpenAI-compatible ChatCompletion object.
            chat_completions = await managed.chat_completion(
                messages=chat,
                n=self.config.group_size,  # Generate multiple completions for GRPO
                max_tokens=thinking_len,
                temperature=1.0,
                top_p=0.95,
            )
            # Get tracked sequences with aligned tokens and logprobs
            state = managed.get_state()
            nodes = state["nodes"]  # List of SequenceNode objects, one per completion
        print("Finished generation", flush=True)
        to_score = list()
@@ -397,14 +416,18 @@ class MathEnv(BaseEnv):
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": chat_completion.message.content},
)
# Extract pre-computed data from SequenceNode
# node.tokens: Full unmasked tokens [prompt + completion]
# node.masked_tokens: [-100, ..., -100, tok1, tok2, ...] for training
# node.logprobs: [1.0, ..., 1.0, logp1, logp2, ...] for training
to_score.append(
(
messages,
item[1],
item[1], # Ground truth answer
chat_completion.finish_reason,
node.tokens,
node.masked_tokens,
node.logprobs,
node.tokens, # Pre-computed by ManagedServer
node.masked_tokens, # Pre-computed by ManagedServer
node.logprobs, # Pre-computed by ManagedServer
)
)
print("scoring normal", flush=True)
@@ -776,7 +799,23 @@ class MathEnv(BaseEnv):
            self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
        )
        # ============================================================================
        # MULTIPLE MANAGED SERVER CONTEXTS - RLAIF Pattern
        # ============================================================================
        # This demonstrates using SEPARATE managed_server contexts for independent
        # completions. Each context tracks its own set of sequences.
        #
        # Pattern: Create separate async functions that each use their own context,
        # then gather them in parallel (a sketch follows this comment block). This
        # is useful for:
        # - RLAIF (forward/backward preference judgments)
        # - Multi-step workflows where completions don't extend each other
        # - Comparing different prompts or conditions
        #
        # Note: The tokens/masks/logprobs from these contexts are NOT used directly
        # in this RLAIF workflow. Instead, we stored them earlier from the original
        # completions (see lines 461-471 where they're added to backlog_item).
        # ============================================================================
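        # A hedged sketch of the gather step (get_bwd_completion is assumed to be
        # defined alongside get_fwd_completion below, mirroring the forward and
        # backward judgments described above; asyncio is assumed imported):
        #
        #     fwd_completions, bwd_completions = await asyncio.gather(
        #         get_fwd_completion(), get_bwd_completion()
        #     )
        #
        # Each call opens its own managed_server context, so the two sets of
        # tracked sequences never mix.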
        async def get_fwd_completion():
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                return await managed.chat_completion(
@@ -896,7 +935,8 @@ class MathEnv(BaseEnv):
        max_token_length = self.config.max_token_length - len(
            self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
        )
        # Judge completions: Standard managed_server usage
        # Tokens/masks/logprobs from nodes will be used directly for training
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=chat,
@@ -990,7 +1030,10 @@ class MathEnv(BaseEnv):
                        retry_messages, add_generation_prompt=True
                    )
                )
                # Retry/self-correction completions: nested managed_server usage.
                # This demonstrates using managed_server INSIDE another workflow.
                # Tokens/masks/logprobs from retry_nodes will be stored in the
                # backlog for potential use in the "selfcorrect" trajectory type
                # (see lines 1070-1077)
                async with self.server.managed_server(
                    tokenizer=self.tokenizer
                ) as managed:
@@ -1031,7 +1074,9 @@ class MathEnv(BaseEnv):
                    )
                )
                backlog_reasons.append(retry_chat_completion.finish_reason)
                # Store the pre-computed tokens/masks/logprobs from ManagedServer.
                # These are passed through the backlog (lines 1110-1116) and
                # eventually used in the collect_trajectories "selfcorrect" case
                # (lines 620-636)
                backlog_tokens.append(retry_node.tokens)
                backlog_masks.append(retry_node.masked_tokens)
                backlog_logprobs.append(retry_node.logprobs)