Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-27 17:23:08 +00:00
add docs :)
parent c3a118f50d
commit 5d6d6bb0dc
6 changed files with 892 additions and 21 deletions
@@ -316,33 +316,68 @@ class MathEnv(BaseEnv):
             curr_length += self.config.start_tok_length
         thinking_len = min(thinking_len, curr_length)
 
-        # Use managed server for automatic token/logprob tracking
+        # ============================================================================
+        # MANAGED SERVER USAGE - Automatic Token & Logprob Tracking
+        # ============================================================================
+        # This is the RECOMMENDED approach for handling inference in Atropos environments.
+        # ManagedServer automatically:
+        # 1. Tokenizes the prompt and completion
+        # 2. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
+        # 3. Applies proper logprob masking (1.0 for prompt, actual values for completion)
+        # 4. Ensures perfect alignment between tokens and logprobs
+        # 5. Handles the n>1 case (multiple completions from same prompt)
+        #
+        # Benefits over manual handling:
+        # - No manual tokenization needed
+        # - No off-by-one errors
+        # - No manual masking calculations
+        # - Guaranteed correct alignment
+        # - Clean, simple code
+        #
+        # See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
+        # ============================================================================
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Call completion as usual, but through the managed server wrapper
+            # This returns a standard OpenAI-compatible Completion object
             completion = await managed.completion(
                 prompt=user_prompt,
-                n=self.config.group_size,
+                n=self.config.group_size,  # Generate multiple completions for GRPO
                 max_tokens=thinking_len,
                 temperature=1.0,
                 top_p=1.0,
                 stop=stop_list,
             )
 
-            # Get tracked sequences with aligned tokens and logprobs
+            # Get the tracked sequences with aligned tokens and logprobs
             state = managed.get_state()
-            nodes = state["nodes"]
+            nodes = state["nodes"]  # List of SequenceNode objects, one per completion
 
+            # ============================================================================
+            # Extract Pre-Computed Data from SequenceNodes
+            # ============================================================================
+            # Each SequenceNode contains:
+            # - full_text: Complete text (prompt + completion)
+            # - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
+            # - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
+            # - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
+            # - metadata: Contains finish_reason, etc.
+            #
+            # Note: -100 is used for prompt token masking (standard PyTorch ignore_index)
+            #       1.0 is used for prompt logprob masking (obviously bad probability)
+            # ============================================================================
+
             # Extract data from SequenceNodes for scoring
             to_score = list()
             to_backlog = list()
             for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
                 to_score.append(
                     (
                         node.full_text,  # Complete text (prompt + completion)
-                        item[1],  # Answer
-                        choice.finish_reason,  # finish_reason (already a clean string)
-                        node.tokens,  # all tokens (prompt + completion)
-                        node.masked_tokens,  # masked tokens (already formatted correctly)
-                        node.logprobs,  # logprobs (already formatted correctly)
+                        item[1],  # Ground truth answer
+                        choice.finish_reason,  # "stop" or "length"
+                        node.tokens,  # Full unmasked tokens [prompt + completion]
+                        node.masked_tokens,  # [-100, ..., -100, tok1, tok2, ...]
+                        node.logprobs,  # [1.0, ..., 1.0, logp1, logp2, ...]
                     )
                 )
         to_postprocess = await self.score(to_score)
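
To make the masking convention in the new comments concrete, here is a minimal standalone sketch of the tokens / masked_tokens / logprobs layout that a SequenceNode carries. The variable names (prompt_ids, completion_ids, completion_logprobs) and the token IDs are hypothetical stand-ins; in a real environment, ManagedServer performs all of this bookkeeping for you, so treat this as an illustration of the invariant rather than the library's implementation.

    IGNORE_INDEX = -100   # standard PyTorch ignore_index, used to mask prompt tokens
    PROMPT_LOGPROB = 1.0  # sentinel logprob for prompt positions (real logprobs are <= 0)

    prompt_ids = [101, 2023, 2003]            # hypothetical tokenized prompt
    completion_ids = [1037, 3231, 102]        # hypothetical tokenized completion
    completion_logprobs = [-0.5, -0.3, -1.2]  # hypothetical logprobs for the completion

    # Full unmasked sequence: prompt followed by completion
    tokens = prompt_ids + completion_ids

    # Training-format tokens: prompt positions masked so the loss ignores them
    masked_tokens = [IGNORE_INDEX] * len(prompt_ids) + completion_ids

    # Training-format logprobs: prompt positions filled with the 1.0 sentinel
    logprobs = [PROMPT_LOGPROB] * len(prompt_ids) + completion_logprobs

    # The alignment guarantee the comments promise: one entry per position,
    # with prompt positions masked in both views and completion positions intact.
    assert len(tokens) == len(masked_tokens) == len(logprobs)
    assert masked_tokens[len(prompt_ids):] == tokens[len(prompt_ids):]

Because the completion request above passes n=self.config.group_size, get_state() returns one node per generated choice; the zip(completion.choices, nodes) in the diff relies on choices and nodes coming back in the same order, which is what the comments mean by ManagedServer handling the n>1 case.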