mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add docs :)
This commit is contained in: parent c3a118f50d, commit 5d6d6bb0dc
6 changed files with 892 additions and 21 deletions

@ -125,6 +125,7 @@ Environment Used: [https://github.com/NousResearch/atropos/blob/main/environment

Key Documents:

- [Base Environment Class](atroposlib/envs/README.md) - Documentation for creating custom environments
- [ManagedServer Guide](atroposlib/envs/server_handling/MANAGED_SERVER.md) - **Recommended approach** for automatic token and logprob tracking
- [Environments Overview and Contribution Guide](environments/community/README.md) - Documentation for existing environments and how to contribute new ones
- [Full Environment Config Options](CONFIG.md) - Documentation for all available environment configuration options
- [Example Trainer](example_trainer/README.md) - Getting started with training

@ -142,6 +142,12 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
    images: Optional[Any] = None  # Image data (if applicable)
    env_id: Optional[int] = None  # ID of the environment that generated this data
```
* **Expected Data Format:**
  * `tokens`: Full unmasked token sequences (prompt + completion)
  * `masks`: Token sequences for training, with **`-100` for prompt positions** and actual token IDs for completion positions
  * `inference_logprobs`: Optional logprob sequences for training, with **`1.0` for masked positions** and actual logprob values for completion positions
    * Why **1.0** for masked logprobs? A logprob of 1.0 corresponds to a probability of e^1.0 ≈ 2.718 > 1, which is impossible, making masked positions easy to identify during training
  * **Recommended:** Use [ManagedServer](../envs/server_handling/MANAGED_SERVER.md) in your environment to automatically produce this format
* **Response:**
  * Normal submission: `{"status": "received"}`
  * Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`

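As an illustration of this format, a minimal payload can be assembled and sanity-checked in plain Python. This is a sketch: the token IDs, logprob values, and the `scores` field shown here are made-up example data, not output from a real environment.

```python
# Hypothetical prompt of 3 tokens followed by a 2-token completion
tokens = [101, 2054, 2003, 42, 7]
masks = [-100, -100, -100, 42, 7]                  # -100 masks prompt positions
inference_logprobs = [1.0, 1.0, 1.0, -0.5, -1.2]   # 1.0 masks prompt positions

# Sanity checks: all three sequences must be the same length, and every
# prompt position masked in `masks` must also be masked in the logprobs
assert len(tokens) == len(masks) == len(inference_logprobs)
assert all(lp == 1.0 for m, lp in zip(masks, inference_logprobs) if m == -100)

payload = {
    "tokens": [tokens],
    "masks": [masks],
    "inference_logprobs": [inference_logprobs],
    "scores": [1.0],
}
print(payload["masks"][0][:3])  # [-100, -100, -100]
```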
@ -17,6 +17,98 @@ To achieve this generality, our environment abstraction deviates from other open
- **Environments return tokens (not messages!)**: One of the most peculiar design choices we made was that, at least for text-only environments, environments are responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal, where an OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55).

## Working with Servers and ManagedServer

**🎯 Recommended Approach:** Use `ManagedServer` for automatic token and logprob tracking!

When implementing `collect_trajectory` or `collect_trajectories`, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The **recommended way** to do this is using `ManagedServer`, which automatically handles tokenization, masking, and logprob alignment.

### ManagedServer Overview

`ManagedServer` wraps your `APIServer` and automatically tracks:

- **Tokens**: Full unmasked token sequences
- **Masked Tokens**: Training format with `-100` for prompt positions, actual token IDs for completion
- **Logprobs**: Training format with `1.0` for masked positions, actual logprob values for completion
- **Full Text**: Complete text (prompt + completion)
- **Metadata**: Finish reasons and other information

**Why 1.0 for masked logprobs?** A logprob of 1.0 corresponds to a probability of e^1.0 ≈ 2.718 > 1, which is invalid, making masked positions easy to identify and ignore during training.

### Basic Usage Pattern
```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server with context manager
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

    # Extract pre-computed, guaranteed-aligned data
    for choice, node in zip(completion.choices, nodes):
        tokens = node.tokens                            # ✅ Automatically computed
        masked_tokens = node.masked_tokens              # ✅ Automatically masked
        logprobs = node.logprobs                        # ✅ Automatically aligned
        finish_reason = node.metadata["finish_reason"]

    # Score and return...
```

### Chat Completion Pattern

For chat-based environments, use `chat_completion()`:

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes...
```

### Benefits Over Manual Handling

❌ **Without ManagedServer:**
- Manually tokenize prompts and completions
- Manually compute prompt lengths
- Manually apply masking logic
- Manually extract and align logprobs
- Prone to off-by-one errors

✅ **With ManagedServer:**
- Automatic tokenization
- Automatic masking
- Guaranteed alignment
- Clean, simple code
- Works with both `completion()` and `chat_completion()` APIs

### Complete Documentation

For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and migration guide, see:

📚 **[ManagedServer Complete Guide](server_handling/MANAGED_SERVER.md)**

## Core Methods to Implement

These methods **must** be implemented in your subclass:


692 atroposlib/envs/server_handling/MANAGED_SERVER.md Normal file

@ -0,0 +1,692 @@
# ManagedServer: Automatic Token and Logprob Tracking

## Overview

`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.

### Why Use ManagedServer?

**Before ManagedServer** (manual approach):
```python
# Manual token extraction
response = await self.server.completion(prompt=prompt, n=8)
# Manually tokenize and align
tokens = self.tokenizer.encode(prompt + response.text)
# Manually apply masking
prompt_len = len(self.tokenizer.encode(prompt))
masked_tokens = [-100] * prompt_len + tokens[prompt_len:]
# Manually extract and align logprobs
logprobs = extract_logprobs_somehow(response)
```

**With ManagedServer** (automatic):
```python
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    response = await managed.completion(prompt=prompt, n=8)
    state = managed.get_state()
    nodes = state["nodes"]
    # tokens, masked_tokens, and logprobs are already aligned and ready!
```

### Key Benefits

- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
- ✅ **Branching Support**: Handles n>1 completions naturally
- ✅ **Clean API**: Simple context manager pattern
- ✅ **Less Error-Prone**: Eliminates common token alignment bugs

## Core Concepts

### SequenceNode Structure

Each completion tracked by ManagedServer is stored as a `SequenceNode`:

```python
class SequenceNode(BaseModel):
    full_text: str                      # Complete text (prompt + completion)
    tokens: List[int]                   # Full token sequence (unmasked)
    masked_tokens: List[int]            # Tokens for training (-100 for prompt, actual IDs for completion)
    logprobs: List[float]               # Logprobs for training (1.0 for prompt, actual values for completion)
    metadata: Optional[Dict[str, Any]]  # Contains finish_reason, etc.
```

### Masking Methodology

ManagedServer applies automatic masking to distinguish between prompt and completion:

| Field | Masked Positions | Completion Positions | Purpose |
|-------|------------------|----------------------|---------|
| `tokens` | Actual token IDs | Actual token IDs | Full unmasked sequence |
| `masked_tokens` | **-100** | Actual token IDs | Training input (mask prompts) |
| `logprobs` | **1.0** | Actual logprob values | Training target (mask prompts) |

**Why 1.0 for masked logprobs?**

The value 1.0 is used to indicate "obviously bad" logprobs for prompt positions:
- `e^1.0 ≈ 2.718`, which would represent a probability > 1.0 (invalid)
- This makes masked positions easy to identify and filter during training
- Trainers should ignore positions where `logprobs > 0.0` or where `masked_tokens == -100`

**Example:**

```python
# Prompt: "What is 2+2?"
# Completion: " 4"
# Tokenized: [1, 1867, 374, 220, 17, 10, 17, 30] + [220, 19]

node.tokens = [1, 1867, 374, 220, 17, 10, 17, 30, 220, 19]
node.masked_tokens = [-100, -100, -100, -100, -100, -100, -100, -100, 220, 19]
node.logprobs = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.342, -0.156]
```

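Given this convention, a trainer can recover the completion-only logprobs with a simple filter. A minimal sketch in plain Python (not Atropos code; the values are the ones from the example above):

```python
def completion_logprobs(masked_tokens, logprobs):
    """Return logprobs only for completion positions.

    Prompt positions carry masked_tokens == -100 and logprobs == 1.0,
    so either field can serve as the mask; here we use masked_tokens.
    """
    return [lp for tok, lp in zip(masked_tokens, logprobs) if tok != -100]

masked_tokens = [-100, -100, -100, -100, -100, -100, -100, -100, 220, 19]
logprobs = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.342, -0.156]

print(completion_logprobs(masked_tokens, logprobs))  # [-0.342, -0.156]
```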
### Two Operating Modes

ManagedServer supports two modes for tracking sequences:

#### 1. Default Mode (track_tree=False)
- Maintains a simple list of current nodes
- When a new prompt **extends** an existing node's `full_text`, it **replaces** that node
- Best for most RL scenarios (GRPO, DPO, etc.)
- Accessed via `state["nodes"]`

```python
async with server.managed_server(tokenizer=tokenizer) as managed:
    # First completion
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1

    # Extension (prompt starts with previous full_text)
    await managed.completion(prompt="Hello World", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1 (replaced, not added)
```

#### 2. Tree Mode (track_tree=True)
- Maintains a dictionary of nodes keyed by `full_text`
- Every unique `full_text` creates a new entry
- Useful for multi-turn RL with per-step advantages
- Accessed via `state["sequences"]` or `state["tree"]`

```python
managed = ManagedServer(server, tokenizer=tokenizer, track_tree=True)
```

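The bookkeeping behind these two modes can be sketched in plain Python. This is an illustrative model only, not the actual ManagedServer implementation; the `full_text` keys and the replace-on-extend behavior follow the rules described above:

```python
class SketchTracker:
    """Toy model of the two tracking modes described above."""

    def __init__(self, track_tree=False):
        self.track_tree = track_tree
        self.nodes = []      # default mode: list, updated in place
        self.sequences = {}  # tree mode: keyed by full_text

    def add(self, full_text):
        if self.track_tree:
            # Tree mode: every unique full_text gets its own entry
            self.sequences[full_text] = {"full_text": full_text}
        else:
            # Default mode: an extension replaces the node it extends
            for i, node in enumerate(self.nodes):
                if full_text.startswith(node["full_text"]):
                    self.nodes[i] = {"full_text": full_text}
                    return
            self.nodes.append({"full_text": full_text})

tracker = SketchTracker()
tracker.add("Hello World")
tracker.add("Hello World! How are you?")  # extends -> replaces
print(len(tracker.nodes))  # 1
```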
## Usage Patterns

### Pattern 1: Basic Single-Turn (Completion API)

Use with completion-style prompts (like math_server_zero.py):

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,  # e.g., 16
            max_tokens=4096,
            temperature=1.0,
            top_p=1.0,
        )

        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes for training
    to_score = []
    for choice, node in zip(completion.choices, nodes):
        to_score.append({
            "full_text": node.full_text,
            "tokens": node.tokens,
            "masked_tokens": node.masked_tokens,
            "logprobs": node.logprobs,
            "finish_reason": node.metadata["finish_reason"],
        })

    return await self.score(to_score)
```

### Pattern 2: Basic Single-Turn (Chat Completion API)

Use with chat messages (like math_server.py):

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
            top_p=0.95,
        )

        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes
    to_score = []
    for choice, node in zip(chat_completion.choices, nodes):
        to_score.append({
            "content": choice.message.content,
            "tokens": node.tokens,
            "masked_tokens": node.masked_tokens,
            "logprobs": node.logprobs,
            "finish_reason": choice.finish_reason,
        })

    return await self.score(to_score)
```

### Pattern 3: Multi-Turn Conversations

ManagedServer automatically detects when a prompt extends a previous sequence:

```python
# Turn 1
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    # nodes[0].full_text = "Hello World"

    # Turn 2: Extends turn 1
    # This prompt starts with "Hello World" (turn 1's full_text)
    await managed.completion(prompt="Hello World! How are you?", n=1)
    state = managed.get_state()
    # nodes[0].full_text = "Hello World! How are you? I'm great!"
    # The node from turn 1 has been replaced with the extended version
```

**How Extension Detection Works:**
1. ManagedServer checks if the new prompt starts with any existing node's `full_text`
2. If yes, it reuses those tokens and only tokenizes the new suffix
3. The extended node replaces the original in the list

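The suffix-reuse step above can be sketched independently of Atropos. Here a toy character-level `encode` stands in for the real tokenizer (an assumption for illustration; real tokenizers are not prefix-stable in general):

```python
def extend_node_tokens(prev_full_text, prev_tokens, new_prompt, encode):
    """If new_prompt extends prev_full_text, reuse its tokens and
    tokenize only the new suffix; otherwise tokenize from scratch."""
    if new_prompt.startswith(prev_full_text):
        suffix = new_prompt[len(prev_full_text):]
        return prev_tokens + encode(suffix)
    return encode(new_prompt)

encode = lambda text: [ord(c) for c in text]  # toy stand-in tokenizer

prev_tokens = encode("Hello World")
extended = extend_node_tokens(
    "Hello World", prev_tokens, "Hello World! How are you?", encode
)
assert extended == encode("Hello World! How are you?")
```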
### Pattern 4: Multiple Contexts in One Method

You can use multiple managed_server contexts for complex workflows:

```python
async def collect_trajectories_rlaif(self, item):
    # First set of completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_fwd = await managed.chat_completion(
            messages=messages_fwd,
            n=3,
            temperature=1.0,
        )
        state_fwd = managed.get_state()

    # Second set of completions (independent context)
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_bwd = await managed.chat_completion(
            messages=messages_bwd,
            n=3,
            temperature=1.0,
        )
        state_bwd = managed.get_state()

    # Process both sets
    nodes_fwd = state_fwd["nodes"]
    nodes_bwd = state_bwd["nodes"]
```

### Pattern 5: Passing Tokens Through Backlog

For complex multi-step workflows, you can pass pre-computed tokens/masks/logprobs through backlog items:

```python
async def collect_trajectories_normal(self, item):
    # Generate initial completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        response = await managed.chat_completion(messages=chat, n=16)
        state = managed.get_state()
        nodes = state["nodes"]

    # Find interesting pairs for RLAIF
    if should_do_rlaif:
        # Pass tokens/masks/logprobs to next stage
        backlog_item = (
            item["problem"],
            item["answer"],
            "rlaif",  # Type marker
            messages_1,
            messages_2,
            # Pre-computed data from managed_server
            nodes[idx1].tokens,         # Solution 1 tokens
            nodes[idx1].masked_tokens,  # Solution 1 masks
            nodes[idx1].logprobs,       # Solution 1 logprobs
            nodes[idx2].tokens,         # Solution 2 tokens
            nodes[idx2].masked_tokens,  # Solution 2 masks
            nodes[idx2].logprobs,       # Solution 2 logprobs
        )
        return None, [backlog_item]

async def collect_trajectories_rlaif(self, item):
    # Extract pre-computed data
    tokens_1 = item[5]
    masks_1 = item[6]
    logprobs_1 = item[7]
    tokens_2 = item[8]
    masks_2 = item[9]
    logprobs_2 = item[10]

    # Do RLAIF judgment...
    # Use pre-computed tokens/masks/logprobs directly
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
    }
```

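Positional indices like `item[5]` are easy to get wrong as a backlog item grows. One way to make such an item self-describing is a small dataclass; this is a design sketch, not part of the Atropos API, and the field names are invented for illustration:

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class RlaifBacklogItem:
    """Named fields instead of positional tuple indices."""
    problem: str
    answer: str
    kind: str  # e.g. "rlaif" type marker
    messages_1: Any
    messages_2: Any
    tokens_1: List[int]
    masks_1: List[int]
    logprobs_1: List[float]
    tokens_2: List[int]
    masks_2: List[int]
    logprobs_2: List[float]

item = RlaifBacklogItem(
    problem="2+2", answer="4", kind="rlaif",
    messages_1=[], messages_2=[],
    tokens_1=[1, 2], masks_1=[-100, 2], logprobs_1=[1.0, -0.3],
    tokens_2=[1, 3], masks_2=[-100, 3], logprobs_2=[1.0, -0.7],
)
print(item.kind)  # rlaif
```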

## Complete Examples

### Example 1: Completion API (math_server_zero.py)

```python
async def collect_trajectories(self, item) -> Tuple[List, List]:
    # Format prompt
    user_prompt = prompt_format.format(
        prompt=problem_format.format(problem=item[0])
    )

    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.encode(user_prompt)
    )

    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=user_prompt,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=1.0,
            stop=stop_list,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

    # Extract data from SequenceNodes for scoring
    to_score = []
    for choice, node in zip(completion.choices, nodes):
        to_score.append((
            node.full_text,        # Complete text (prompt + completion)
            item[1],               # Answer
            choice.finish_reason,  # Finish reason
            node.tokens,           # All tokens (prompt + completion)
            node.masked_tokens,    # Masked tokens (-100 for prompt, IDs for completion)
            node.logprobs,         # Logprobs (1.0 for prompt, actual for completion)
        ))

    # Score and return
    to_postprocess = await self.score(to_score)
    return to_postprocess, []
```

### Example 2: Chat Completion API (math_server.py)

```python
async def collect_trajectories_normal(self, item) -> Tuple[List, List]:
    # Prepare chat messages
    user_prompt = problem_format.format(problem=item[0])
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
    )

    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completions = await managed.chat_completion(
            messages=chat,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=0.95,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

    # Extract data from SequenceNodes for scoring
    to_score = []
    for chat_completion, node in zip(chat_completions.choices, nodes):
        messages = (
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": chat_completion.message.content},
        )
        to_score.append((
            messages,                        # Full conversation
            item[1],                         # Answer
            chat_completion.finish_reason,   # Finish reason
            node.tokens,                     # All tokens
            node.masked_tokens,              # Masked tokens
            node.logprobs,                   # Logprobs
        ))

    # Score and return
    to_postprocess = await self.score_normal(to_score)
    return to_postprocess, []
```

### Example 3: RLAIF with Multiple Contexts (math_server.py)

```python
async def collect_trajectories_rlaif(self, item) -> Tuple[List, List]:
    # Prepare forward and backward prompts
    user_prompt_fwd = rlaif_format.format(
        problem=item[0],
        solution1=solution1_text,
        solution2=solution2_text,
    )
    user_prompt_bwd = rlaif_format.format(
        problem=item[0],
        solution1=solution2_text,  # Swapped
        solution2=solution1_text,  # Swapped
    )

    # Generate both forward and backward judgments in parallel
    async def get_fwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_fwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    async def get_bwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_bwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    # Gather both completions
    completions_fwd, completions_bwd = await asyncio.gather(
        get_fwd_completion(),
        get_bwd_completion(),
    )

    # Extract pre-computed tokens/masks/logprobs from item
    # (These were stored when the original solutions were generated)
    tokens_1 = item[6]
    masks_1 = item[7]
    logprobs_1 = item[8]
    tokens_2 = item[9]
    masks_2 = item[10]
    logprobs_2 = item[11]

    # Score based on judgments...
    score_1, score_2 = calculate_scores(completions_fwd, completions_bwd)

    # Return using pre-computed tokens
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
        "messages": [messages_1, messages_2],
    }, []
```

## Migration from Manual Token Handling

### Before: Manual Approach

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Call server
    completion = await self.server.completion(
        prompt=prompt,
        n=8,
        max_tokens=4096,
        logprobs=True,
    )

    # Manually handle tokens
    to_score = []
    for choice in completion.choices:
        # Manually tokenize full text
        full_text = prompt + choice.text
        tokens = self.tokenizer.encode(full_text, add_special_tokens=True)

        # Manually compute prompt length
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_len = len(prompt_tokens)

        # Manually apply masking
        masked_tokens = [-100] * prompt_len + tokens[prompt_len:]

        # Manually extract and align logprobs (error-prone!)
        logprobs = [1.0] * prompt_len
        if hasattr(choice, 'logprobs') and choice.logprobs:
            for logprob_obj in choice.logprobs:
                logprobs.append(logprob_obj.logprob)

        # Manually pad/truncate to match length
        while len(logprobs) < len(tokens):
            logprobs.append(1.0)

        to_score.append({
            "tokens": tokens,
            "masked_tokens": masked_tokens,
            "logprobs": logprobs,
        })
```

### After: ManagedServer Approach

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server - everything automatic!
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=8,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

    # Extract pre-computed, guaranteed-aligned data
    to_score = []
    for node in nodes:
        to_score.append({
            "tokens": node.tokens,                # ✅ Automatically computed
            "masked_tokens": node.masked_tokens,  # ✅ Automatically masked
            "logprobs": node.logprobs,            # ✅ Automatically aligned
        })
```

**Benefits:**
- ❌ No manual tokenization needed
- ❌ No manual masking calculations
- ❌ No logprob extraction and alignment
- ❌ No off-by-one errors
- ✅ Clean, simple code
- ✅ Guaranteed correctness

## API Reference

### ManagedServer Class

```python
class ManagedServer:
    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        """
        Initialize the managed server.

        Args:
            server: The underlying APIServer instance to wrap
            tokenizer: Tokenizer for encoding/decoding. If not provided,
                will attempt to extract from server or create from model name.
            track_tree: If True, maintains a tree structure with parent-child links.
                If False (default), maintains a simple list that updates in-place.
        """
```

### Methods

#### `async def chat_completion(**kwargs) -> ChatCompletion`
Intercept chat completion call and track sequences.

**Args:**
- `messages`: List of message dicts with 'role' and 'content'
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard chat completion parameters

**Returns:**
- `ChatCompletion` response (same as OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage
- Updates `current_nodes` list (default mode) or `sequences` dict (tree mode)

#### `async def completion(**kwargs) -> Completion`
Intercept completion call and track sequences.

**Args:**
- `prompt`: The prompt string
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard completion parameters

**Returns:**
- `Completion` response (same as OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage

#### `def get_state() -> Dict[str, Any]`
Get the current state of tracked sequences.

**Returns:**
- For default mode (track_tree=False):
  ```python
  {
      "nodes": List[SequenceNode]  # List of tracked sequences
  }
  ```
- For tree mode (track_tree=True):
  ```python
  {
      "sequences": Dict[str, SequenceNode],  # Keyed by full_text
      "tree": Dict[str, SequenceNode],       # Alias for compatibility
  }
  ```

#### `def reset()`
Clear all tracked sequences.

### Context Manager (Recommended Usage)

```python
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
    # Use managed.completion() or managed.chat_completion()
    ...
    # Get state before context exits
    state = managed.get_state()
```

The context manager:
- Creates a `ManagedServer` instance
- Returns it for use within the block
- Automatically cleans up when the block exits

## Best Practices

1. **Always use the context manager pattern** for automatic cleanup:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       ...
   ```

2. **Get state before exiting the context**:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       completion = await managed.completion(...)
       state = managed.get_state()  # ✅ Do this inside the context
   # ❌ Don't try to access state here - context has exited
   ```

3. **Use separate contexts for independent completions**:
   ```python
   # Context 1: Generate candidates
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       candidates = await managed.completion(...)
       state1 = managed.get_state()

   # Context 2: Judge candidates (independent)
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       judgments = await managed.completion(...)
       state2 = managed.get_state()
   ```

## Troubleshooting

### Issue: "Extension detection not working"

**Cause:** The new prompt doesn't exactly start with the previous node's `full_text`.

**Solution:** Ensure prompt strings match exactly, including whitespace:
```python
# Turn 1 produces: "Hello World"
# Turn 2 prompt must be: "Hello World..." (exact prefix match)
```

## Additional Resources

- [ManagedServer Source Code](managed_server.py)
- [ManagedServer Tests](../../tests/test_managed_server.py)
- [Example: math_server_zero.py](../../../../environments/math_server_zero.py#L320-L332)
- [Example: math_server.py](../../../../environments/math_server.py#L377-L387)
- [BaseEnv Documentation](../README.md)
- [API Server Documentation](../../api/README.md)

@ -373,18 +373,37 @@ class MathEnv(BaseEnv):
        if thinking_len < 1024:
            print("thinking_len is less than 1024, skipping", flush=True)
            return None, []

        # ============================================================================
        # MANAGED SERVER USAGE - Chat Completion API
        # ============================================================================
        # This demonstrates using ManagedServer with the chat_completion() API.
        # The process is identical to the completion() API (see math_server_zero.py),
        # but uses OpenAI chat message format instead of raw text prompts.
        #
        # ManagedServer automatically:
        # 1. Applies the tokenizer's chat template to convert messages to text
        # 2. Tokenizes both prompt and completion
        # 3. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
        # 4. Applies proper logprob masking (1.0 for prompt, actual values for completion)
        # 5. Ensures perfect alignment between tokens and logprobs
        #
        # See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
        # ============================================================================

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Call chat_completion through the managed server wrapper
            # Returns standard OpenAI-compatible ChatCompletion object
            chat_completions = await managed.chat_completion(
                messages=chat,
                n=self.config.group_size,  # Generate multiple completions for GRPO
                max_tokens=thinking_len,
                temperature=1.0,
                top_p=0.95,
            )
            # Get tracked sequences with aligned tokens and logprobs
            state = managed.get_state()
            nodes = state["nodes"]  # List of SequenceNode objects, one per completion

        print("Finished generation", flush=True)
        to_score = list()

@@ -397,14 +416,18 @@ class MathEnv(BaseEnv):
                 {"role": "user", "content": user_prompt},
                 {"role": "assistant", "content": chat_completion.message.content},
             )
+            # Extract pre-computed data from SequenceNode
+            # node.tokens: Full unmasked tokens [prompt + completion]
+            # node.masked_tokens: [-100, ..., -100, tok1, tok2, ...] for training
+            # node.logprobs: [1.0, ..., 1.0, logp1, logp2, ...] for training
             to_score.append(
                 (
                     messages,
-                    item[1],
+                    item[1],  # Ground truth answer
                     chat_completion.finish_reason,
-                    node.tokens,
-                    node.masked_tokens,
-                    node.logprobs,
+                    node.tokens,  # Pre-computed by ManagedServer
+                    node.masked_tokens,  # Pre-computed by ManagedServer
+                    node.logprobs,  # Pre-computed by ManagedServer
                 )
             )
         print("scoring normal", flush=True)
@@ -776,7 +799,23 @@ class MathEnv(BaseEnv):
             self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
         )
 
-        # Use managed server for both forward and backward completions
+        # ============================================================================
+        # MULTIPLE MANAGED SERVER CONTEXTS - RLAIF Pattern
+        # ============================================================================
+        # This demonstrates using SEPARATE managed_server contexts for independent
+        # completions. Each context tracks its own set of sequences independently.
+        #
+        # Pattern: Create separate async functions that each use their own context,
+        # then gather them in parallel. This is useful for:
+        # - RLAIF (forward/backward preference judgments)
+        # - Multi-step workflows where completions don't extend each other
+        # - Comparing different prompts or conditions
+        #
+        # Note: The tokens/masks/logprobs from these contexts are NOT used directly
+        # in this RLAIF workflow. Instead, we stored them earlier from the original
+        # completions (see lines 461-471 where they're added to backlog_item).
+        # ============================================================================
+
         async def get_fwd_completion():
             async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                 return await managed.chat_completion(
@@ -896,7 +935,8 @@ class MathEnv(BaseEnv):
         max_token_length = self.config.max_token_length - len(
             self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
         )
-        # Use managed server for judge completions
+        # Judge completions: Standard managed_server usage
+        # Tokens/masks/logprobs from nodes will be used directly for training
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
             chat_completions = await managed.chat_completion(
                 messages=chat,
@@ -990,7 +1030,10 @@ class MathEnv(BaseEnv):
                     retry_messages, add_generation_prompt=True
                 )
             )
-            # Use managed server for retry completions
+            # Retry/self-correction completions: Nested managed_server usage
+            # This demonstrates using managed_server INSIDE another workflow.
+            # Tokens/masks/logprobs from retry_nodes will be stored in backlog
+            # for potential use in the "selfcorrect" trajectory type (see lines 1070-1077)
             async with self.server.managed_server(
                 tokenizer=self.tokenizer
             ) as managed:
@@ -1031,7 +1074,9 @@ class MathEnv(BaseEnv):
                     )
                 )
                 backlog_reasons.append(retry_chat_completion.finish_reason)
-                # Store tokens, masks, and logprobs from managed_server
+                # Store pre-computed tokens/masks/logprobs from ManagedServer
+                # These will be passed through the backlog (line 1110-1116) and
+                # eventually used in collect_trajectories "selfcorrect" case (line 620-636)
                 backlog_tokens.append(retry_node.tokens)
                 backlog_masks.append(retry_node.masked_tokens)
                 backlog_logprobs.append(retry_node.logprobs)

@@ -316,33 +316,68 @@ class MathEnv(BaseEnv):
         curr_length += self.config.start_tok_length
         thinking_len = min(thinking_len, curr_length)
 
-        # Use managed server for automatic token/logprob tracking
+        # ============================================================================
+        # MANAGED SERVER USAGE - Automatic Token & Logprob Tracking
+        # ============================================================================
+        # This is the RECOMMENDED approach for handling inference in Atropos environments.
+        # ManagedServer automatically:
+        # 1. Tokenizes the prompt and completion
+        # 2. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
+        # 3. Applies proper logprob masking (1.0 for prompt, actual values for completion)
+        # 4. Ensures perfect alignment between tokens and logprobs
+        # 5. Handles the n>1 case (multiple completions from same prompt)
+        #
+        # Benefits over manual handling:
+        # - No manual tokenization needed
+        # - No off-by-one errors
+        # - No manual masking calculations
+        # - Guaranteed correct alignment
+        # - Clean, simple code
+        #
+        # See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
+        # ============================================================================
+
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Call completion as usual, but through the managed server wrapper
+            # This returns a standard OpenAI-compatible Completion object
             completion = await managed.completion(
                 prompt=user_prompt,
-                n=self.config.group_size,
+                n=self.config.group_size,  # Generate multiple completions for GRPO
                 max_tokens=thinking_len,
                 temperature=1.0,
                 top_p=1.0,
                 stop=stop_list,
             )
 
-            # Get tracked sequences with aligned tokens and logprobs
+            # Get the tracked sequences with aligned tokens and logprobs
             state = managed.get_state()
-            nodes = state["nodes"]
+            nodes = state["nodes"]  # List of SequenceNode objects, one per completion
 
-            # Extract data from SequenceNodes for scoring
+            # ============================================================================
+            # Extract Pre-Computed Data from SequenceNodes
+            # ============================================================================
+            # Each SequenceNode contains:
+            # - full_text: Complete text (prompt + completion)
+            # - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
+            # - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
+            # - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
+            # - metadata: Contains finish_reason, etc.
+            #
+            # Note: -100 is used for prompt token masking (standard PyTorch ignore_index)
+            #       1.0 is used for prompt logprob masking (obviously bad probability)
+            # ============================================================================
+
             to_score = list()
             to_backlog = list()
             for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
                 to_score.append(
                     (
                         node.full_text,  # Complete text (prompt + completion)
-                        item[1],  # Answer
-                        choice.finish_reason,  # finish_reason (already a clean string)
-                        node.tokens,  # all tokens (prompt + completion)
-                        node.masked_tokens,  # masked tokens (already formatted correctly)
-                        node.logprobs,  # logprobs (already formatted correctly)
+                        item[1],  # Ground truth answer
+                        choice.finish_reason,  # "stop" or "length"
+                        node.tokens,  # Full unmasked tokens [prompt + completion]
+                        node.masked_tokens,  # [-100, ..., -100, tok1, tok2, ...]
+                        node.logprobs,  # [1.0, ..., 1.0, logp1, logp2, ...]
                     )
                 )
         to_postprocess = await self.score(to_score)