add docs :)

This commit is contained in:
Dakota 2025-10-29 11:26:43 -05:00
parent c3a118f50d
commit 5d6d6bb0dc
6 changed files with 892 additions and 21 deletions

View file

@@ -316,33 +316,68 @@ class MathEnv(BaseEnv):
curr_length += self.config.start_tok_length
thinking_len = min(thinking_len, curr_length)
# Use managed server for automatic token/logprob tracking
# ============================================================================
# MANAGED SERVER USAGE - Automatic Token & Logprob Tracking
# ============================================================================
# This is the RECOMMENDED approach for handling inference in Atropos environments.
# ManagedServer automatically:
# 1. Tokenizes the prompt and completion
# 2. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
# 3. Applies proper logprob masking (1.0 for prompt, actual values for completion)
# 4. Ensures perfect alignment between tokens and logprobs
# 5. Handles the n>1 case (multiple completions from the same prompt)
#
# Benefits over manual handling:
# - No manual tokenization needed
# - No off-by-one errors
# - No manual masking calculations
# - Guaranteed correct alignment
# - Clean, simple code
#
# See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
# ============================================================================
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Call completion as usual, but through the managed server wrapper
# This returns a standard OpenAI-compatible Completion object
completion = await managed.completion(
prompt=user_prompt,
n=self.config.group_size,
n=self.config.group_size, # Generate multiple completions for GRPO
max_tokens=thinking_len,
temperature=1.0,
top_p=1.0,
stop=stop_list,
)
# Get tracked sequences with aligned tokens and logprobs
# Get the tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
nodes = state["nodes"] # List of SequenceNode objects, one per completion
# ============================================================================
# Extract Pre-Computed Data from SequenceNodes
# ============================================================================
# Each SequenceNode contains:
# - full_text: Complete text (prompt + completion)
# - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
# - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
# - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
# - metadata: Contains finish_reason, etc.
#
# Note: -100 is used for prompt token masking (standard PyTorch ignore_index)
# 1.0 is used for prompt logprob masking (an impossible value: real logprobs are always <= 0, so it is an unmistakable sentinel)
# ============================================================================
# Extract data from SequenceNodes for scoring
to_score = list()
to_backlog = list()
for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
to_score.append(
(
node.full_text, # Complete text (prompt + completion)
item[1], # Answer
choice.finish_reason, # finish_reason (already a clean string)
node.tokens, # all tokens (prompt + completion)
node.masked_tokens, # masked tokens (already formatted correctly)
node.logprobs, # logprobs (already formatted correctly)
item[1], # Ground truth answer
choice.finish_reason, # "stop" or "length"
node.tokens, # Full unmasked tokens [prompt + completion]
node.masked_tokens, # [-100, ..., -100, tok1, tok2, ...]
node.logprobs, # [1.0, ..., 1.0, logp1, logp2, ...]
)
)
to_postprocess = await self.score(to_score)