atropos/atroposlib/envs/server_handling/MANAGED_SERVER.md
Jai Suphavadeeprasit 1eeb31065f fixing comments
2026-03-03 23:16:05 -05:00


# ManagedServer: Automatic Token and Logprob Tracking
## Overview
`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It also exposes a normalized `get_logprobs(...)` API for backend-agnostic logprob access. This eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.

**Server Compatibility:** ManagedServer works with `VLLMServer`, `SGLangServer`, and `TrlVllmServer`. Set the `server_type` field in your `APIServerConfig` to `"vllm"`, `"sglang"`, or `"trl"`, and the matching server class is selected automatically.
> **⚠️ OpenAI Endpoints:** OpenAI's API does not expose token IDs or detailed logprobs required for full ManagedServer functionality. See [OpenAI Endpoint Limitations](#openai-endpoint-limitations) for details and workarounds.
### Why Use ManagedServer?
**Before ManagedServer** (manual approach):
```python
# Manual token extraction
response = await self.server.completion(prompt=prompt, n=8)
completion_text = response.choices[0].text  # ...and repeat for every choice
# Manually tokenize and align
tokens = self.tokenizer.encode(prompt + completion_text)
# Manually apply masking
prompt_len = len(self.tokenizer.encode(prompt))
masked_tokens = [-100] * prompt_len + tokens[prompt_len:]
# Manually extract and align logprobs
logprobs = extract_logprobs_somehow(response)
```
**With ManagedServer** (automatic):
```python
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    response = await managed.completion(prompt=prompt, n=8)
    state = managed.get_state()
    nodes = state["nodes"]
    # tokens, masked_tokens, and logprobs are already aligned and ready!
```
### Key Benefits
- **Automatic Tokenization**: No need to manually tokenize prompts and completions
- **Automatic Masking**: Prompt tokens are masked with -100, and prompt-position logprobs with 1.0
- **Perfect Alignment**: Tokens and logprobs align position-by-position for every tracked sequence
- **Normalized Alignment Contract**: Tokens and logprobs are shape-normalized for downstream consumers
- **Multi-turn Support**: Conversation extensions are detected and handled automatically
- **Branching Support**: Handles n>1 completions naturally
- **Clean API**: Simple context manager pattern
- **Less Error-Prone**: Eliminates common token alignment bugs
## Core Concepts
### SequenceNode Structure
Each completion tracked by ManagedServer is stored as a `SequenceNode`:
```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class SequenceNode(BaseModel):
    full_text: str  # Complete text (prompt + completion)
    tokens: List[int]  # Full token sequence (unmasked)
    masked_tokens: List[int]  # Tokens for training (-100 for prompt, actual IDs for completion)
    logprobs: List[float]  # Logprobs for training (1.0 for prompt, actual values for completion)
    metadata: Optional[Dict[str, Any]]  # Contains finish_reason, etc.
```
### Masking Methodology
ManagedServer applies automatic masking to distinguish between prompt and completion positions:

| Field | Prompt Positions (masked) | Completion Positions | Purpose |
|-------|---------------------------|----------------------|---------|
| `tokens` | Actual token IDs | Actual token IDs | Full unmasked sequence |
| `masked_tokens` | **-100** | Actual token IDs | Training input (prompts masked) |
| `logprobs` | **1.0** | Actual logprob values | Training target (prompts masked) |

**Why 1.0 for masked logprobs?**

A real logprob is the log of a probability in (0, 1], so it is always ≤ 0.0. The value 1.0 is therefore an impossible, "obviously bad" logprob that marks prompt positions:
- `e^1.0 ≈ 2.718`, which would represent a probability greater than 1 (invalid)
- This makes masked positions easy to identify and filter during training
- Trainers should ignore positions where `logprobs > 0.0` or where `masked_tokens == -100`

**Example:**
```python
# Prompt: "What is 2+2?"
# Completion: " 4"
# Tokenized: [1, 1867, 374, 220, 17, 10, 17, 30] + [220, 19]
node.tokens = [1, 1867, 374, 220, 17, 10, 17, 30, 220, 19]
node.masked_tokens = [-100, -100, -100, -100, -100, -100, -100, -100, 220, 19]
node.logprobs = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.342, -0.156]
```
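The masking table and the "What is 2+2?" example can be reproduced with a small, self-contained sketch. `build_masked_fields` and `completion_positions` are hypothetical helpers written for illustration, not part of ManagedServer's API; they show how the three arrays relate and how a trainer might keep only completion positions using the two rules above.

```python
from typing import List, Tuple

PROMPT_TOKEN_MASK = -100   # value placed in masked_tokens at prompt positions
PROMPT_LOGPROB_MASK = 1.0  # impossible logprob (> 0.0) placed at prompt positions


def build_masked_fields(
    prompt_tokens: List[int],
    completion_tokens: List[int],
    completion_logprobs: List[float],
) -> Tuple[List[int], List[int], List[float]]:
    """Return (tokens, masked_tokens, logprobs) following the masking table."""
    if len(completion_tokens) != len(completion_logprobs):
        raise ValueError("need exactly one logprob per completion token")
    n_prompt = len(prompt_tokens)
    tokens = prompt_tokens + completion_tokens
    masked_tokens = [PROMPT_TOKEN_MASK] * n_prompt + completion_tokens
    logprobs = [PROMPT_LOGPROB_MASK] * n_prompt + completion_logprobs
    return tokens, masked_tokens, logprobs


def completion_positions(masked_tokens: List[int], logprobs: List[float]) -> List[int]:
    """Indices a trainer should keep: not -100 and a valid logprob (<= 0.0)."""
    return [
        i
        for i, (tok, lp) in enumerate(zip(masked_tokens, logprobs))
        if tok != PROMPT_TOKEN_MASK and lp <= 0.0
    ]


# Reproduce the example above (token IDs are illustrative).
tokens, masked, lps = build_masked_fields(
    [1, 1867, 374, 220, 17, 10, 17, 30], [220, 19], [-0.342, -0.156]
)
keep = completion_positions(masked, lps)
print(keep)  # [8, 9]
completion_logprob = sum(lps[i] for i in keep)  # summed logprob of the completion
```

Checking both conditions is redundant when the arrays are produced consistently, but it makes the filter robust if one of the two arrays is consumed on its own.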
### Two Operating Modes
ManagedServer supports two modes for tracking sequences:
#### 1. Default Mode (track_tree=False)
- Maintains a simple list of current nodes
- When a new prompt **extends** an existing node's `full_text`, it **replaces** that node
- Best for most RL scenarios (GRPO, DPO, etc.)
- Accessed via `state["nodes"]`
```python
async with server.managed_server(tokenizer=tokenizer) as managed:
    # First completion
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1
    # Extension: assume the model completed "Hello" with " World", so this
    # prompt starts with the previous node's full_text ("Hello World")
    await managed.completion(prompt="Hello World", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1 (replaced, not added)
#### 2. Tree Mode (track_tree=True)
- Maintains a dictionary of nodes keyed by `full_text`
- Every unique `full_text` creates a new entry
- Useful for multi-turn RL with per-step advantages
- Accessed via `state["sequences"]` or `state["tree"]`
```python
managed = ManagedServer(server, tokenizer=tokenizer, track_tree=True)
```
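The difference between the two modes described above can be sketched with a text-only toy tracker. `ToyTracker` is hypothetical: it stores plain strings instead of `SequenceNode` objects and ignores tokens entirely, so it only illustrates the replace-versus-add bookkeeping, not ManagedServer's actual implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyTracker:
    """Hypothetical stand-in for ManagedServer's node bookkeeping (text only)."""

    track_tree: bool = False
    nodes: List[str] = field(default_factory=list)  # default mode: current full_texts
    sequences: Dict[str, str] = field(default_factory=dict)  # tree mode: full_text -> completion

    def record(self, prompt: str, completion: str) -> None:
        full_text = prompt + completion
        if self.track_tree:
            # Tree mode: every unique full_text gets its own entry.
            self.sequences[full_text] = completion
            return
        # Default mode: a prompt that extends an existing node replaces it.
        for i, existing in enumerate(self.nodes):
            if prompt.startswith(existing):
                self.nodes[i] = full_text
                return
        # Otherwise (first turn, or another branch of n>1) append a new node.
        self.nodes.append(full_text)


flat = ToyTracker()
flat.record("Hello", " World")
flat.record("Hello World", "!")
print(flat.nodes)  # ['Hello World!']

tree = ToyTracker(track_tree=True)
tree.record("Hello", " World")
tree.record("Hello World", "!")
print(sorted(tree.sequences))  # ['Hello World', 'Hello World!']
```

Default mode keeps only the latest version of each conversation, which is what a single-reward-per-trajectory setup needs; tree mode keeps every intermediate `full_text`, which is what per-step advantages need.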
## Usage Patterns
### Pattern 1: Basic Single-Turn (Completion API)
Use with completion-style prompts (like math_server_zero.py):
```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)
    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,  # e.g., 16
            max_tokens=4096,
            temperature=1.0,
            top_p=1.0,
        )
        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]
    # Process nodes for training
    to_score = []
    for choice, node in zip(completion.choices, nodes):
        to_score.append({
            "full_text": node.full_text,
            "tokens": node.tokens,
            "masked_tokens": node.masked_tokens,
            "logprobs": node.logprobs,
            "finish_reason": node.metadata["finish_reason"],
        })
    return await self.score(to_score)
```
### Pattern 2: Basic Single-Turn (Chat Completion API)
Use with chat messages (like math_server.py):
```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]
    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
            top_p=0.95,
        )
        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]
    # Process nodes
    to_score = []
    for choice, node in zip(chat_completion.choices, nodes):
        to_score.append({
            "content": choice.message.content,
            "tokens": node.tokens,
            "masked_tokens": node.masked_tokens,
            "logprobs": node.logprobs,
            "finish_reason": choice.finish_reason,
        })
    return await self.score(to_score)
```
### Pattern 3: Multi-Turn Conversations
ManagedServer automatically detects when a prompt extends a previous sequence:
```python
# Turn 1
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    # state["nodes"][0].full_text == "Hello World"  (the model completed " World")
    # Turn 2: Extends turn 1
    # This prompt starts with "Hello World" (turn 1's full_text)
    await managed.completion(prompt="Hello World! How are you?", n=1)
    state = managed.get_state()
    # state["nodes"][0].full_text == "Hello World! How are you? I'm great!"
    # The node from turn 1 has been replaced with the extended version
```
**How Extension Detection Works:**
1. ManagedServer checks if the new prompt starts with any existing node's `full_text`
2. If yes, it reuses those tokens and only tokenizes the new suffix
3. The extended node replaces the original in the list
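The three steps above can be sketched at the token level as follows. This is a minimal sketch under stated assumptions, not ManagedServer's code: `Node`, `find_extended`, `extend_node`, and the character-level `encode` are hypothetical, and it assumes the earlier turn's masks and logprobs are carried over unchanged while the new prompt suffix is masked like any other prompt text. With real subword tokenizers, tokenizing a suffix on its own can produce different tokens at the boundary than tokenizing the whole string; the toy tokenizer sidesteps that issue.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Node:
    full_text: str
    tokens: List[int]
    masked_tokens: List[int]
    logprobs: List[float]


def encode(text: str) -> List[int]:
    # Toy character-level tokenizer so the sketch runs without a model.
    return [ord(c) for c in text]


def find_extended(nodes: List[Node], prompt: str) -> Optional[int]:
    """Step 1: index of a node whose full_text is a prefix of the new prompt."""
    for i, node in enumerate(nodes):
        if prompt.startswith(node.full_text):
            return i
    return None


def extend_node(
    node: Node,
    prompt: str,
    completion_text: str,
    completion_tokens: List[int],
    completion_logprobs: List[float],
    tokenize: Callable[[str], List[int]],
) -> Node:
    """Step 2: reuse the node's tokens and tokenize only the new prompt suffix."""
    suffix_tokens = tokenize(prompt[len(node.full_text):])
    return Node(
        full_text=prompt + completion_text,
        tokens=node.tokens + suffix_tokens + completion_tokens,
        # Assumption: earlier turns keep their masks; the new suffix is masked.
        masked_tokens=node.masked_tokens + [-100] * len(suffix_tokens) + completion_tokens,
        logprobs=node.logprobs + [1.0] * len(suffix_tokens) + completion_logprobs,
    )


# Turn 1: prompt "Hello" (5 tokens, masked), completion " World" (6 tokens).
nodes = [
    Node(
        full_text="Hello World",
        tokens=encode("Hello World"),
        masked_tokens=[-100] * 5 + encode(" World"),
        logprobs=[1.0] * 5 + [-0.1] * 6,
    )
]

# Turn 2: the prompt extends turn 1, so only "! How are you?" is tokenized.
prompt = "Hello World! How are you?"
idx = find_extended(nodes, prompt)
if idx is not None:
    # Step 3: the extended node replaces the original in the list.
    nodes[idx] = extend_node(
        nodes[idx], prompt, " Great!", encode(" Great!"), [-0.2] * 7, encode
    )

print(len(nodes), nodes[0].full_text)  # 1 Hello World! How are you? Great!
print(len(nodes[0].tokens))            # 32
```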
### Pattern 4: Multiple Contexts in One Method
You can use multiple managed_server contexts for complex workflows:
```python
async def collect_trajectories_rlaif(self, item):
    # First set of completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_fwd = await managed.chat_completion(
            messages=messages_fwd,
            n=3,
            temperature=1.0,
        )
        state_fwd = managed.get_state()
    # Second set of completions (independent context)
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_bwd = await managed.chat_completion(
            messages=messages_bwd,
            n=3,
            temperature=1.0,
        )
        state_bwd = managed.get_state()
    # Process both sets
    nodes_fwd = state_fwd["nodes"]
    nodes_bwd = state_bwd["nodes"]
```
### Pattern 5: Passing Tokens Through Backlog
For complex multi-step workflows, you can pass pre-computed tokens/masks/logprobs through backlog items:
```python
async def collect_trajectories_normal(self, item):
    # Generate initial completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        response = await managed.chat_completion(messages=chat, n=16)
        state = managed.get_state()
        nodes = state["nodes"]
    # Find interesting pairs for RLAIF
    if should_do_rlaif:
        # Pass tokens/masks/logprobs to next stage
        backlog_item = (
            item["problem"],
            item["answer"],
            "rlaif",  # Type marker
            messages_1,
            messages_2,
            # Pre-computed data from managed_server
            nodes[idx1].tokens,  # Solution 1 tokens
            nodes[idx1].masked_tokens,  # Solution 1 masks
            nodes[idx1].logprobs,  # Solution 1 logprobs
            nodes[idx2].tokens,  # Solution 2 tokens
            nodes[idx2].masked_tokens,  # Solution 2 masks
            nodes[idx2].logprobs,  # Solution 2 logprobs
        )
        return None, [backlog_item]


async def collect_trajectories_rlaif(self, item):
    # Extract pre-computed data (indices match the backlog tuple above)
    tokens_1 = item[5]
    masks_1 = item[6]
    logprobs_1 = item[7]
    tokens_2 = item[8]
    masks_2 = item[9]
    logprobs_2 = item[10]
    # Do RLAIF judgment...
    # Use pre-computed tokens/masks/logprobs directly
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
    }, []
```
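Pattern 5 relies on positional indices (`item[5]` through `item[10]`) that must agree between the method that builds the backlog tuple and the method that consumes it. One way to keep them in sync, shown here as a hypothetical refactor rather than Atropos's own code, is a `typing.NamedTuple`: it is still a plain tuple, so existing index-based code keeps working, but fields can also be read by name. `RlaifBacklogItem` and the sample values are illustrative.

```python
from typing import Any, List, NamedTuple


class RlaifBacklogItem(NamedTuple):
    """Hypothetical named layout for the Pattern 5 backlog tuple."""

    problem: str
    answer: str
    kind: str  # type marker, e.g. "rlaif"
    messages_1: List[Any]
    messages_2: List[Any]
    tokens_1: List[int]
    masks_1: List[int]
    logprobs_1: List[float]
    tokens_2: List[int]
    masks_2: List[int]
    logprobs_2: List[float]


item = RlaifBacklogItem(
    problem="What is 2+2?",
    answer="4",
    kind="rlaif",
    messages_1=[],
    messages_2=[],
    tokens_1=[1, 2, 3],
    masks_1=[-100, 2, 3],
    logprobs_1=[1.0, -0.5, -0.25],
    tokens_2=[1, 4],
    masks_2=[-100, 4],
    logprobs_2=[1.0, -0.75],
)

# Positional access still matches the layout used in Pattern 5 (item[5]..item[10]).
print(item[5], item.tokens_1)  # [1, 2, 3] [1, 2, 3]
```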
## Complete Examples
### Example 1: Completion API (math_server_zero.py)
```python
async def collect_trajectories(self, item) -> Tuple[List, List]:
    # Format prompt
    user_prompt = prompt_format.format(
        prompt=problem_format.format(problem=item[0])
    )
    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.encode(user_prompt)
    )
    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=user_prompt,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=1.0,
            stop=stop_list,
        )
        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]
    # Extract data from SequenceNodes for scoring
    to_score = []
    for choice, node in zip(completion.choices, nodes):
        to_score.append((
            node.full_text,  # Complete text (prompt + completion)
            item[1],  # Answer
            choice.finish_reason,  # Finish reason
            node.tokens,  # All tokens (prompt + completion)
            node.masked_tokens,  # Masked tokens (-100 for prompt, IDs for completion)
            node.logprobs,  # Logprobs (1.0 for prompt, actual for completion)
        ))
    # Score and return
    to_postprocess = await self.score(to_score)
    return to_postprocess, []
```
### Example 2: Chat Completion API (math_server.py)
```python
async def collect_trajectories_normal(self, item) -> Tuple[List, List]:
    # Prepare chat messages
    user_prompt = problem_format.format(problem=item[0])
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
    )
    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completions = await managed.chat_completion(
            messages=chat,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=0.95,
        )
        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]
    # Extract data from SequenceNodes for scoring
    to_score = []
    for chat_completion, node in zip(chat_completions.choices, nodes):
        messages = (
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": chat_completion.message.content},
        )
        to_score.append((
            messages,  # Full conversation
            item[1],  # Answer
            chat_completion.finish_reason,  # Finish reason
            node.tokens,  # All tokens
            node.masked_tokens,  # Masked tokens
            node.logprobs,  # Logprobs
        ))
    # Score and return
    to_postprocess = await self.score_normal(to_score)
    return to_postprocess, []
```
### Example 3: RLAIF with Multiple Contexts (math_server.py)
```python
async def collect_trajectories_rlaif(self, item) -> Tuple[List, List]:
    # Prepare forward and backward prompts
    user_prompt_fwd = rlaif_format.format(
        problem=item[0],
        solution1=solution1_text,
        solution2=solution2_text,
    )
    user_prompt_bwd = rlaif_format.format(
        problem=item[0],
        solution1=solution2_text,  # Swapped
        solution2=solution1_text,  # Swapped
    )

    # Generate both forward and backward judgments in parallel;
    # each helper opens its own independent managed_server context
    async def get_fwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_fwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    async def get_bwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_bwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    # Gather both completions (requires `import asyncio` at module level)
    completions_fwd, completions_bwd = await asyncio.gather(
        get_fwd_completion(),
        get_bwd_completion(),
    )
    # Extract pre-computed tokens/masks/logprobs from item
    # (These were stored when the original solutions were generated; the
    # indices must match the layout of the backlog tuple that was queued)
    tokens_1 = item[6]
    masks_1 = item[7]
    logprobs_1 = item[8]
    tokens_2 = item[9]
    masks_2 = item[10]
    logprobs_2 = item[11]
    # Score based on judgments...
    score_1, score_2 = calculate_scores(completions_fwd, completions_bwd)
    # Return using pre-computed tokens
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
        "messages": [messages_1, messages_2],
    }, []
```
## Migration from Manual Token Handling
### Before: Manual Approach
```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)
    # Call server
    completion = await self.server.completion(
        prompt=prompt,
        n=8,
        max_tokens=4096,
        logprobs=1,  # the Completions API takes an integer here
    )
    # Manually handle tokens
    to_score = []
    for choice in completion.choices:
        # Manually tokenize full text (may not match the tokens actually sampled)
        full_text = prompt + choice.text
        tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
        # Manually compute prompt length
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_len = len(prompt_tokens)
        # Manually apply masking
        masked_tokens = [-100] * prompt_len + tokens[prompt_len:]
        # Manually extract and align logprobs (error-prone!)
        logprobs = [1.0] * prompt_len
        if choice.logprobs is not None:
            logprobs.extend(choice.logprobs.token_logprobs or [])
        # Manually pad/truncate to match length
        logprobs = (logprobs + [1.0] * len(tokens))[: len(tokens)]
        to_score.append({
            "tokens": tokens,
            "masked_tokens": masked_tokens,
            "logprobs": logprobs,
        })
```
### After: ManagedServer Approach
```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)
    # Use managed server - everything automatic!
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=8,
            max_tokens=4096,
        )
        state = managed.get_state()
        nodes = state["nodes"]
    # Extract pre-computed, guaranteed-aligned data
    to_score = []
    for node in nodes:
        to_score.append({
            "tokens": node.tokens,  # ✅ Automatically computed
            "masked_tokens": node.masked_tokens,  # ✅ Automatically masked
            "logprobs": node.logprobs,  # ✅ Automatically aligned
        })
```
**Benefits:**
- ❌ No manual tokenization needed
- ❌ No manual masking calculations
- ❌ No logprob extraction and alignment
- ❌ No off-by-one errors between tokens, masks, and logprobs
- ✅ Clean, simple code
- ✅ Alignment handled in one place instead of in every environment
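The off-by-one and alignment problems in the manual approach come from re-tokenizing decoded text: a tokenizer can merge characters across the prompt/completion boundary, so the re-encoded sequence no longer matches the tokens the server actually sampled. The toy greedy longest-match tokenizer below reproduces the effect in a few lines; its vocabulary and the sampled IDs/logprobs are hypothetical, chosen only to trigger a boundary merge.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; subword vocabularies can merge across boundaries similarly.
VOCAB: Dict[str, int] = {"Hello": 0, "Hell": 1, "o": 2, " World": 3}


def greedy_encode(text: str) -> List[int]:
    """Greedy longest-match tokenizer: always takes the longest vocab entry."""
    ids, i = [], 0
    while i < len(text):
        for end in range(len(text), i, -1):
            piece = text[i:end]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                i = end
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return ids


prompt = "Hell"
sampled_ids = [VOCAB["o"], VOCAB[" World"]]  # what the server actually generated
sampled_logprobs = [-0.4, -0.2]

# Manual approach: re-tokenize prompt + completion text and slice by prompt length.
tokens = greedy_encode(prompt + "o World")
prompt_len = len(greedy_encode(prompt))
logprobs = [1.0] * prompt_len + sampled_logprobs

print(tokens)                               # [0, 3]: "Hello" merged across the boundary
print(len(tokens), len(logprobs))           # 2 3: the arrays no longer line up
print(tokens[prompt_len:] == sampled_ids)   # False
```

Working from the backend's own token IDs avoids this class of bug, which is also why the OpenAI limitations described below matter for training.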
## API Reference
### ManagedServer Class
```python
class ManagedServer:
    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        """
        Initialize the managed server.

        Args:
            server: The underlying APIServer instance to wrap
            tokenizer: Tokenizer for encoding/decoding. If not provided,
                will attempt to extract from server or create from model name.
            track_tree: If True, maintains a tree structure with parent-child links.
                If False (default), maintains a simple list that updates in-place.
        """
```
### Methods
#### `async def chat_completion(**kwargs) -> ChatCompletion`
Intercepts a chat completion call and tracks the resulting sequences.

**Args:**
- `messages`: List of message dicts with 'role' and 'content'
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard chat completion parameters

**Returns:**
- `ChatCompletion` response (same schema as the OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage
- Updates `current_nodes` list (default mode) or `sequences` dict (tree mode)
#### `async def completion(**kwargs) -> Completion`
Intercepts a completion call and tracks the resulting sequences.

**Args:**
- `prompt`: The prompt string
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard completion parameters

**Returns:**
- `Completion` response (same schema as the OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage
#### `async def get_logprobs(**kwargs) -> Dict[str, Any]`
Fetch logprobs with a normalized schema that is backend-agnostic.
**Args (common):**
- `messages` or `prompt` or `input_ids`
- `n`: Number of sampled sequences
- `max_tokens`
- Optional backend kwargs such as `top_k` / `top_logprobs`, `temperature`, `stop`

**Returns (normalized):**
```python
{
"prompt_tokens": List[int],
"prompt_topk_token_ids": List[List[int]], # [pos][k]
"prompt_topk_logprobs": List[List[float]], # [pos][k]
}
```
**Notes:**
- Strict mode: backend must provide real prompt top-k arrays.
- Missing keys should be treated as backend contract violations.
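A consumer can enforce the normalized contract before using the arrays. `validate_prompt_logprobs` below is a hypothetical helper, not part of ManagedServer. It assumes that each `[pos]` row lines up with one entry of `prompt_tokens` and that every position carries the same number `k` of alternatives; if a backend returns fewer entries for some positions (for example the first prompt token), relax the width check accordingly.

```python
from typing import Any, Dict, List


def validate_prompt_logprobs(payload: Dict[str, Any]) -> int:
    """Check the normalized get_logprobs(...) payload and return k.

    Raises on missing keys or ragged [pos][k] arrays instead of silently
    filling in placeholders, in the spirit of the strict-mode note.
    """
    required = ("prompt_tokens", "prompt_topk_token_ids", "prompt_topk_logprobs")
    missing = [key for key in required if key not in payload]
    if missing:
        raise KeyError(f"backend contract violation, missing keys: {missing}")

    n_pos = len(payload["prompt_tokens"])
    ids: List[List[int]] = payload["prompt_topk_token_ids"]
    lps: List[List[float]] = payload["prompt_topk_logprobs"]
    if len(ids) != n_pos or len(lps) != n_pos:
        raise ValueError("top-k arrays must have one row per prompt position")

    widths = {len(row) for row in ids} | {len(row) for row in lps}
    if len(widths) > 1:
        raise ValueError(f"inconsistent k across positions: {sorted(widths)}")
    return widths.pop() if widths else 0


payload = {
    "prompt_tokens": [1, 1867, 374],
    "prompt_topk_token_ids": [[1, 2], [1867, 5], [374, 9]],
    "prompt_topk_logprobs": [[-0.1, -2.3], [-0.5, -1.2], [-0.05, -3.0]],
}
k = validate_prompt_logprobs(payload)
print(k)  # 2
```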
#### `def get_state() -> Dict[str, Any]`
Get the current state of tracked sequences.
**Returns:**
- For default mode (track_tree=False):
  ```python
  {
      "nodes": List[SequenceNode]  # List of tracked sequences
  }
  ```
- For tree mode (track_tree=True):
  ```python
  {
      "sequences": Dict[str, SequenceNode],  # Keyed by full_text
      "tree": Dict[str, SequenceNode],  # Alias for compatibility
  }
  ```
#### `def reset()`
Clear all tracked sequences.
### Context Manager (Recommended Usage)
```python
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
    # Use managed.completion() or managed.chat_completion()
    ...
    # Get state before context exits
    state = managed.get_state()
```
The context manager:
- Creates a `ManagedServer` instance
- Returns it for use within the block
- Automatically cleans up when the block exits
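The lifecycle described above can be modeled with `contextlib.asynccontextmanager`. This is a toy model, not the real `managed_server()`: `ToyManaged` and `toy_managed_server` are hypothetical, and the sketch assumes "cleans up" means tracked state is cleared on exit. It shows why state should be read inside the block, and why separate blocks run under `asyncio.gather` (as in Example 3) do not share nodes.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, List, Tuple


class ToyManaged:
    """Hypothetical stand-in for ManagedServer that only records texts."""

    def __init__(self) -> None:
        self.nodes: List[str] = []

    async def completion(self, prompt: str) -> str:
        text = prompt + " ...done"  # pretend the model answered
        self.nodes.append(text)
        return text

    def get_state(self) -> Dict[str, List[str]]:
        return {"nodes": list(self.nodes)}  # copy, so callers keep a snapshot

    def reset(self) -> None:
        self.nodes.clear()


@asynccontextmanager
async def toy_managed_server() -> AsyncIterator[ToyManaged]:
    managed = ToyManaged()  # a fresh instance for every block
    try:
        yield managed
    finally:
        managed.reset()  # assumed cleanup: tracked state is dropped on exit


async def run_one(prompt: str) -> Dict[str, List[str]]:
    async with toy_managed_server() as managed:
        await managed.completion(prompt)
        return managed.get_state()  # snapshot taken before cleanup runs


async def main() -> Tuple[Dict[str, List[str]], ...]:
    async with toy_managed_server() as managed:
        await managed.completion("Hello")
        inside = managed.get_state()  # read inside the block
    after = managed.get_state()  # cleanup has already run
    # Independent contexts can run concurrently without sharing nodes.
    fwd, bwd = await asyncio.gather(run_one("fwd"), run_one("bwd"))
    return inside, after, fwd, bwd


inside, after, fwd, bwd = asyncio.run(main())
print(inside)    # {'nodes': ['Hello ...done']}
print(after)     # {'nodes': []}
print(fwd, bwd)  # {'nodes': ['fwd ...done']} {'nodes': ['bwd ...done']}
```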
## Best Practices
1. **Always use the context manager pattern** for automatic cleanup:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       ...
   ```
2. **Get state before exiting the context**:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       completion = await managed.completion(...)
       state = managed.get_state()  # ✅ Do this inside the context
   # ❌ Don't read state here: the context has exited and cleanup has run
   ```
3. **Use separate contexts for independent completions**:
   ```python
   # Context 1: Generate candidates
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       candidates = await managed.completion(...)
       state1 = managed.get_state()
   # Context 2: Judge candidates (independent)
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       judgments = await managed.completion(...)
       state2 = managed.get_state()
   ```
## Troubleshooting
### Issue: "Extension detection not working"
**Cause:** The new prompt doesn't start exactly with the previous node's `full_text`.

**Solution:** Ensure prompt strings match exactly, including whitespace:
```python
# Turn 1 produces: "Hello World"
# Turn 2 prompt must be: "Hello World..." (exact prefix match)
```
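When extension detection fails, it helps to see exactly where the new prompt stops matching. `first_divergence` is a hypothetical diagnostic helper, not part of ManagedServer; it reports the first mismatching character index, which makes stray whitespace or chat-template differences easy to spot.

```python
from typing import Optional


def first_divergence(previous_full_text: str, new_prompt: str) -> Optional[int]:
    """Return the first index where new_prompt stops matching previous_full_text.

    None means new_prompt starts with previous_full_text, i.e. it would be
    treated as an extension by an exact prefix check.
    """
    for i, expected in enumerate(previous_full_text):
        if i >= len(new_prompt) or new_prompt[i] != expected:
            return i
    return None


prev = "Hello World"
print(first_divergence(prev, "Hello World! How are you?"))   # None
print(first_divergence(prev, "Hello  World! How are you?"))  # 6 (extra space)
print(first_divergence(prev, "Hello"))                       # 5 (prompt too short)
```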
## OpenAI Endpoint Limitations
OpenAI's API does not expose token IDs or detailed logprobs in the same way that vLLM, SGLang, and other self-hosted inference servers do. This means **ManagedServer cannot provide accurate token-level training data** when using OpenAI endpoints.
### Default Behavior
By default, attempting to use `managed_server()` with an `OpenAIServer` will raise a `NotImplementedError`:
```python
async with self.server.managed_server() as managed:
    # Raises NotImplementedError if server is OpenAIServer
    ...
```
The error message will explain the limitation and how to opt-in if you don't need real token data.
### DummyManagedServer (Opt-in)
If you're using OpenAI endpoints for **evaluation or testing** (not training) and don't need actual token IDs or logprobs, you can opt-in to use `DummyManagedServer` by setting an environment variable:
```bash
export ATROPOS_ALLOW_DUMMY_MANAGED_SERVER=1
```
With this flag set, `managed_server()` will return a `DummyManagedServer` that:
- Provides the same interface as `ManagedServer`
- Returns **fixed placeholder values** for tokens and logprobs (constant synthetic arrays)
- Uses simple text formatting for `full_text`: `role:content` joined by `\n\n`
- Raises `NotImplementedError` from `get_logprobs(...)` in strict mode (no fake prompt-logprob payload)
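The `full_text` format described above can be sketched as a one-line join. `dummy_full_text` is a hypothetical helper, and two details are assumptions: that there is no space after the colon, and that the caller decides whether the assistant reply is included in the message list.

```python
from typing import Dict, List


def dummy_full_text(messages: List[Dict[str, str]]) -> str:
    """Join messages as "role:content" blocks separated by a blank line."""
    return "\n\n".join(f"{m['role']}:{m['content']}" for m in messages)


text = dummy_full_text([
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"},
])
print(text)
# user:What is 2+2?
#
# assistant:4
```

Because this text is not produced by the model's chat template, it should not be re-tokenized and treated as training input.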
### When to Use DummyManagedServer
✅ **Appropriate uses:**
- Testing environment logic without needing real token data
- Evaluation workflows where you only need completion text
- Prototyping before switching to a self-hosted inference server

❌ **Not appropriate for:**
- Training (tokens and logprobs are meaningless placeholders)
- Any workflow that depends on accurate token-level information
### Example
```python
import os

# Opt in to the dummy managed server for OpenAI
os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "1"

# Now this works with OpenAI endpoints
async with self.server.managed_server() as managed:
    response = await managed.chat_completion(messages=messages, n=4)
    state = managed.get_state()
    nodes = state["nodes"]

# nodes contain placeholder token data - DO NOT use for training
for node in nodes:
    print(node.full_text)  # Real completion text
    print(node.tokens[:5])  # placeholder values
    print(node.logprobs[:5])  # placeholder values

# Strict mode: get_logprobs is not available on DummyManagedServer
# and will raise NotImplementedError.
```
### Recommendation
For training workloads, use a self-hosted inference server (`VLLMServer`, `SGLangServer`, or `TrlVllmServer`) that provides full token and logprob access. OpenAI endpoints are best suited for evaluation, testing, or workflows that only need completion text.
## Additional Resources
- [ManagedServer Source Code](managed_server.py)
- [ManagedServer Tests](../../tests/test_managed_server.py)
- [Example: math_server_zero.py](../../../environments/math_server_zero.py#L320-L332)
- [Example: math_server.py](../../../environments/math_server.py#L377-L387)
- [BaseEnv Documentation](../README.md)
- [API Server Documentation](../../api/README.md)