mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
Merge pull request #264 from NousResearch/add-logprob-server-manager-fn
add sglang specific token level logprob handling and server manager/b…
This commit is contained in:
commit
b1e164eef5
14 changed files with 2730 additions and 153 deletions
@@ -17,6 +17,98 @@ To achieve this generality, our environment abstraction deviates from other open
- **Environments return tokens (not messages!)**: One of the most unusual design choices we made is that, at least for text-only environments, the environment is responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal, where an OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55).
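To make that concrete, here is a minimal sketch of the kind of token-level payload an environment hands back (the keys mirror the `tokens` / `masks` / `scores` / `inference_logprobs` fields used in the examples later on this page; consult `ScoredDataGroup` in `base.py` for the exact schema):

```python
# Hypothetical scored group for a group of 2 completions (keys as used in the
# examples below; the exact schema lives in ScoredDataGroup in base.py).
scored_group = {
    "tokens": [[1, 1867, 374, 220, 17, 10, 17, 30, 220, 19],   # full sequences
               [1, 1867, 374, 220, 17, 10, 17, 30, 220, 20]],
    "masks": [[-100] * 8 + [220, 19],                          # prompt masked with -100
              [-100] * 8 + [220, 20]],
    "inference_logprobs": [[1.0] * 8 + [-0.34, -0.16],         # 1.0 marks masked positions
                           [1.0] * 8 + [-0.88, -1.52]],
    "scores": [1.0, 0.0],                                      # one reward per completion
}
```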
## Working with Servers and ManagedServer

**🎯 Recommended Approach:** Use `ManagedServer` for automatic token and logprob tracking!

When implementing `collect_trajectory` or `collect_trajectories`, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The **recommended way** to do this is using `ManagedServer`, which automatically handles tokenization, masking, and logprob alignment.

### ManagedServer Overview

`ManagedServer` wraps your `APIServer` and automatically tracks:

- **Tokens**: Full unmasked token sequences
- **Masked Tokens**: Training format with `-100` for prompt positions, actual token IDs for the completion
- **Logprobs**: Training format with `1.0` for masked positions, actual logprob values for the completion
- **Full Text**: Complete text (prompt + completion)
- **Metadata**: Finish reasons and other information

**Why 1.0 for masked logprobs?** It represents an "obviously bad" probability (e^1.0 ≈ 2.718, a probability greater than 1, which is invalid), making masked positions easy to identify and ignore during training.
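As a quick illustration, here is a minimal sketch (a hypothetical trainer-side helper, not part of the Atropos API) of filtering masked positions using exactly these conventions:

```python
# Hypothetical sketch: decide which positions contribute to the training loss,
# using the conventions above (-100 token mask, 1.0 logprob mask).
def loss_positions(masked_tokens: list[int], logprobs: list[float]) -> list[bool]:
    return [tok != -100 and lp <= 0.0 for tok, lp in zip(masked_tokens, logprobs)]

# Prompt positions carry (-100, 1.0) and are skipped; completion positions remain.
assert loss_positions([-100, -100, 220, 19], [1.0, 1.0, -0.34, -0.16]) == [False, False, True, True]
```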
### Basic Usage Pattern

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server with context manager
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

        # Extract pre-computed, guaranteed-aligned data
        for choice, node in zip(completion.choices, nodes):
            tokens = node.tokens                # ✅ Automatically computed
            masked_tokens = node.masked_tokens  # ✅ Automatically masked
            logprobs = node.logprobs            # ✅ Automatically aligned
            finish_reason = node.metadata["finish_reason"]

            # Score and return...
```
### Chat Completion Pattern

For chat-based environments, use `chat_completion()`:

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

        # Process nodes...
```
### Benefits Over Manual Handling

❌ **Without ManagedServer:**
- Manually tokenize prompts and completions
- Manually compute prompt lengths
- Manually apply masking logic
- Manually extract and align logprobs
- Prone to off-by-one errors

✅ **With ManagedServer:**
- Automatic tokenization
- Automatic masking
- Guaranteed alignment
- Clean, simple code
- Works with both `completion()` and `chat_completion()` APIs

### Complete Documentation

For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and a migration guide, see:

📚 **[ManagedServer Complete Guide](server_handling/MANAGED_SERVER.md)**

## Core Methods to Implement

These methods **must** be implemented in your subclass:
692  atroposlib/envs/server_handling/MANAGED_SERVER.md  Normal file

@@ -0,0 +1,692 @@
# ManagedServer: Automatic Token and Logprob Tracking

## Overview

`ManagedServer` is a wrapper around `APIServer` that automatically tracks text sequences with aligned tokens and logprobs. It eliminates the need for manual token extraction, alignment, and masking in your environment code, making it **the recommended approach** for handling inference in Atropos environments.

### Why Use ManagedServer?

**Before ManagedServer** (manual approach):
```python
# Manual token extraction
response = await self.server.completion(prompt=prompt, n=8)
# Manually tokenize and align
tokens = self.tokenizer.encode(prompt + response.text)
# Manually apply masking
prompt_len = len(self.tokenizer.encode(prompt))
masked_tokens = [-100] * prompt_len + tokens[prompt_len:]
# Manually extract and align logprobs
logprobs = extract_logprobs_somehow(response)
```

**With ManagedServer** (automatic):
```python
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    response = await managed.completion(prompt=prompt, n=8)
    state = managed.get_state()
    nodes = state["nodes"]
    # tokens, masked_tokens, and logprobs are already aligned and ready!
```

### Key Benefits

- ✅ **Automatic Tokenization**: No need to manually tokenize prompts and completions
- ✅ **Automatic Masking**: Prompt tokens automatically masked with -100, logprobs with 1.0
- ✅ **Perfect Alignment**: Tokens and logprobs are guaranteed to align correctly
- ✅ **Multi-turn Support**: Automatically handles conversation extensions
- ✅ **Branching Support**: Handles n>1 completions naturally
- ✅ **Clean API**: Simple context manager pattern
- ✅ **Less Error-Prone**: Eliminates common token alignment bugs

## Core Concepts

### SequenceNode Structure

Each completion tracked by ManagedServer is stored as a `SequenceNode`:

```python
class SequenceNode(BaseModel):
    full_text: str                      # Complete text (prompt + completion)
    tokens: List[int]                   # Full token sequence (unmasked)
    masked_tokens: List[int]            # Tokens for training (-100 for prompt, actual IDs for completion)
    logprobs: List[float]               # Logprobs for training (1.0 for prompt, actual values for completion)
    metadata: Optional[Dict[str, Any]]  # Contains finish_reason, etc.
```

### Masking Methodology

ManagedServer applies automatic masking to distinguish between prompt and completion:

| Field | Masked Positions | Completion Positions | Purpose |
|-------|------------------|----------------------|---------|
| `tokens` | Actual token IDs | Actual token IDs | Full unmasked sequence |
| `masked_tokens` | **-100** | Actual token IDs | Training input (mask prompts) |
| `logprobs` | **1.0** | Actual logprob values | Training target (mask prompts) |

**Why 1.0 for masked logprobs?**

The value 1.0 is used to indicate "obviously bad" logprobs for prompt positions:
- `e^1.0 ≈ 2.718`, which would represent a probability > 1.0 (invalid)
- This makes masked positions easy to identify and filter during training
- Trainers should ignore positions where `logprobs > 0.0` or where `masked_tokens == -100`

**Example:**

```python
# Prompt: "What is 2+2?"
# Completion: " 4"
# Tokenized: [1, 1867, 374, 220, 17, 10, 17, 30] + [220, 19]

node.tokens        = [1, 1867, 374, 220, 17, 10, 17, 30, 220, 19]
node.masked_tokens = [-100, -100, -100, -100, -100, -100, -100, -100, 220, 19]
node.logprobs      = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.342, -0.156]
```

### Two Operating Modes

ManagedServer supports two modes for tracking sequences:

#### 1. Default Mode (track_tree=False)
- Maintains a simple list of current nodes
- When a new prompt **extends** an existing node's `full_text`, it **replaces** that node
- Best for most RL scenarios (GRPO, DPO, etc.)
- Accessed via `state["nodes"]`

```python
async with server.managed_server(tokenizer=tokenizer) as managed:
    # First completion
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1

    # Extension (prompt starts with previous full_text)
    await managed.completion(prompt="Hello World", n=1)
    state = managed.get_state()
    len(state["nodes"])  # → 1 (replaced, not added)
```

#### 2. Tree Mode (track_tree=True)
- Maintains a dictionary of nodes keyed by `full_text`
- Every unique `full_text` creates a new entry
- Useful for multi-turn RL with per-step advantages
- Accessed via `state["sequences"]` or `state["tree"]`, as sketched after the snippet below

```python
managed = ManagedServer(server, tokenizer=tokenizer, track_tree=True)
```
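In tree mode you read results from the dict-shaped state rather than the node list; a quick sketch:

```python
# Tree mode: every unique full_text is kept, keyed by its text.
state = managed.get_state()
for full_text, node in state["sequences"].items():  # state["tree"] is an alias
    print(len(node.tokens), node.metadata["finish_reason"])
```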
## Usage Patterns

### Pattern 1: Basic Single-Turn (Completion API)

Use with completion-style prompts (like math_server_zero.py):

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,  # e.g., 16
            max_tokens=4096,
            temperature=1.0,
            top_p=1.0,
        )

        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]

        # Process nodes for training
        to_score = []
        for choice, node in zip(completion.choices, nodes):
            to_score.append({
                "full_text": node.full_text,
                "tokens": node.tokens,
                "masked_tokens": node.masked_tokens,
                "logprobs": node.logprobs,
                "finish_reason": node.metadata["finish_reason"],
            })

    return await self.score(to_score)
```

### Pattern 2: Basic Single-Turn (Chat Completion API)

Use with chat messages (like math_server.py):

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    # Use managed server context
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
            top_p=0.95,
        )

        # Get tracked sequences
        state = managed.get_state()
        nodes = state["nodes"]

        # Process nodes
        to_score = []
        for choice, node in zip(chat_completion.choices, nodes):
            to_score.append({
                "content": choice.message.content,
                "tokens": node.tokens,
                "masked_tokens": node.masked_tokens,
                "logprobs": node.logprobs,
                "finish_reason": choice.finish_reason,
            })

    return await self.score(to_score)
```

### Pattern 3: Multi-Turn Conversations

ManagedServer automatically detects when a prompt extends a previous sequence:

```python
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
    # Turn 1
    await managed.completion(prompt="Hello", n=1)
    state = managed.get_state()
    # nodes[0].full_text = "Hello World"

    # Turn 2: extends turn 1
    # This prompt starts with "Hello World" (turn 1's full_text)
    await managed.completion(prompt="Hello World! How are you?", n=1)
    state = managed.get_state()
    # nodes[0].full_text = "Hello World! How are you? I'm great!"
    # The node from turn 1 has been replaced with the extended version
```

**How Extension Detection Works:**
1. ManagedServer checks whether the new prompt starts with any existing node's `full_text`
2. If it does, it reuses those tokens and only tokenizes the new suffix
3. The extended node replaces the original in the list (see the sketch below)
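A condensed sketch of that logic (it mirrors `_find_extending_node` and `_compute_input_ids` in `managed_server.py`, shown in full further down this page):

```python
# Sketch of extension detection: reuse tokens for the matched prefix and
# tokenize only the new suffix (no BOS, since the sequence is continuing).
def extend_input_ids(nodes, input_text, tokenizer):
    for node in nodes:
        if input_text.startswith(node.full_text):
            suffix = input_text[len(node.full_text):]
            return node.tokens + tokenizer.encode(suffix, add_special_tokens=False)
    # No match: brand-new sequence, tokenize from scratch.
    return tokenizer.encode(input_text, add_special_tokens=True)
```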
### Pattern 4: Multiple Contexts in One Method

You can use multiple managed_server contexts for complex workflows:

```python
async def collect_trajectories_rlaif(self, item):
    # First set of completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_fwd = await managed.chat_completion(
            messages=messages_fwd,
            n=3,
            temperature=1.0,
        )
        state_fwd = managed.get_state()

    # Second set of completions (independent context)
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions_bwd = await managed.chat_completion(
            messages=messages_bwd,
            n=3,
            temperature=1.0,
        )
        state_bwd = managed.get_state()

    # Process both sets
    nodes_fwd = state_fwd["nodes"]
    nodes_bwd = state_bwd["nodes"]
```

### Pattern 5: Passing Tokens Through Backlog

For complex multi-step workflows, you can pass pre-computed tokens/masks/logprobs through backlog items:

```python
async def collect_trajectories_normal(self, item):
    # Generate initial completions
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        response = await managed.chat_completion(messages=chat, n=16)
        state = managed.get_state()
        nodes = state["nodes"]

    # Find interesting pairs for RLAIF
    if should_do_rlaif:
        # Pass tokens/masks/logprobs to next stage
        backlog_item = (
            item["problem"],
            item["answer"],
            "rlaif",  # Type marker
            messages_1,
            messages_2,
            # Pre-computed data from managed_server
            nodes[idx1].tokens,         # Solution 1 tokens
            nodes[idx1].masked_tokens,  # Solution 1 masks
            nodes[idx1].logprobs,       # Solution 1 logprobs
            nodes[idx2].tokens,         # Solution 2 tokens
            nodes[idx2].masked_tokens,  # Solution 2 masks
            nodes[idx2].logprobs,       # Solution 2 logprobs
        )
        return None, [backlog_item]


async def collect_trajectories_rlaif(self, item):
    # Extract pre-computed data
    tokens_1 = item[5]
    masks_1 = item[6]
    logprobs_1 = item[7]
    tokens_2 = item[8]
    masks_2 = item[9]
    logprobs_2 = item[10]

    # Do RLAIF judgment...
    # Use pre-computed tokens/masks/logprobs directly
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
    }
```
## Complete Examples

### Example 1: Completion API (math_server_zero.py)

```python
async def collect_trajectories(self, item) -> Tuple[List, List]:
    # Format prompt
    user_prompt = prompt_format.format(
        prompt=problem_format.format(problem=item[0])
    )

    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.encode(user_prompt)
    )

    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=user_prompt,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=1.0,
            stop=stop_list,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

        # Extract data from SequenceNodes for scoring
        to_score = []
        for choice, node in zip(completion.choices, nodes):
            to_score.append((
                node.full_text,        # Complete text (prompt + completion)
                item[1],               # Answer
                choice.finish_reason,  # Finish reason
                node.tokens,           # All tokens (prompt + completion)
                node.masked_tokens,    # Masked tokens (-100 for prompt, IDs for completion)
                node.logprobs,         # Logprobs (1.0 for prompt, actual for completion)
            ))

    # Score and return
    to_postprocess = await self.score(to_score)
    return to_postprocess, []
```

### Example 2: Chat Completion API (math_server.py)

```python
async def collect_trajectories_normal(self, item) -> Tuple[List, List]:
    # Prepare chat messages
    user_prompt = problem_format.format(problem=item[0])
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Calculate max tokens
    thinking_len = self.config.max_token_length - len(
        self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
    )

    # Use managed server for automatic token/logprob tracking
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completions = await managed.chat_completion(
            messages=chat,
            n=self.config.group_size,
            max_tokens=thinking_len,
            temperature=1.0,
            top_p=0.95,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

        # Extract data from SequenceNodes for scoring
        to_score = []
        for chat_completion, node in zip(chat_completions.choices, nodes):
            messages = (
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append((
                messages,                       # Full conversation
                item[1],                        # Answer
                chat_completion.finish_reason,  # Finish reason
                node.tokens,                    # All tokens
                node.masked_tokens,             # Masked tokens
                node.logprobs,                  # Logprobs
            ))

    # Score and return
    to_postprocess = await self.score_normal(to_score)
    return to_postprocess, []
```

### Example 3: RLAIF with Multiple Contexts (math_server.py)

```python
async def collect_trajectories_rlaif(self, item) -> Tuple[List, List]:
    # Prepare forward and backward prompts
    user_prompt_fwd = rlaif_format.format(
        problem=item[0],
        solution1=solution1_text,
        solution2=solution2_text,
    )
    user_prompt_bwd = rlaif_format.format(
        problem=item[0],
        solution1=solution2_text,  # Swapped
        solution2=solution1_text,  # Swapped
    )

    # Generate both forward and backward judgments in parallel
    async def get_fwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_fwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    async def get_bwd_completion():
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            return await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt_bwd},
                ],
                n=3,
                max_tokens=max_tokens,
                temperature=1.0,
                top_p=0.95,
            )

    # Gather both completions
    completions_fwd, completions_bwd = await asyncio.gather(
        get_fwd_completion(),
        get_bwd_completion(),
    )

    # Extract pre-computed tokens/masks/logprobs from item
    # (These were stored when the original solutions were generated)
    tokens_1 = item[6]
    masks_1 = item[7]
    logprobs_1 = item[8]
    tokens_2 = item[9]
    masks_2 = item[10]
    logprobs_2 = item[11]

    # Score based on judgments...
    score_1, score_2 = calculate_scores(completions_fwd, completions_bwd)

    # Return using pre-computed tokens
    return {
        "tokens": [tokens_1, tokens_2],
        "masks": [masks_1, masks_2],
        "inference_logprobs": [logprobs_1, logprobs_2],
        "scores": [score_1, score_2],
        "messages": [messages_1, messages_2],
    }, []
```
## Migration from Manual Token Handling

### Before: Manual Approach

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Call server
    completion = await self.server.completion(
        prompt=prompt,
        n=8,
        max_tokens=4096,
        logprobs=True,
    )

    # Manually handle tokens
    to_score = []
    for choice in completion.choices:
        # Manually tokenize full text
        full_text = prompt + choice.text
        tokens = self.tokenizer.encode(full_text, add_special_tokens=True)

        # Manually compute prompt length
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_len = len(prompt_tokens)

        # Manually apply masking
        masked_tokens = [-100] * prompt_len + tokens[prompt_len:]

        # Manually extract and align logprobs (error-prone!)
        logprobs = [1.0] * prompt_len
        if hasattr(choice, 'logprobs') and choice.logprobs:
            for logprob_obj in choice.logprobs:
                logprobs.append(logprob_obj.logprob)

        # Manually pad/truncate to match length
        while len(logprobs) < len(tokens):
            logprobs.append(1.0)

        to_score.append({
            "tokens": tokens,
            "masked_tokens": masked_tokens,
            "logprobs": logprobs,
        })
```

### After: ManagedServer Approach

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server - everything automatic!
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=8,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

        # Extract pre-computed, guaranteed-aligned data
        to_score = []
        for node in nodes:
            to_score.append({
                "tokens": node.tokens,                # ✅ Automatically computed
                "masked_tokens": node.masked_tokens,  # ✅ Automatically masked
                "logprobs": node.logprobs,            # ✅ Automatically aligned
            })
```

**Benefits:**
- ❌ No manual tokenization needed
- ❌ No manual masking calculations
- ❌ No logprob extraction and alignment
- ❌ No off-by-one errors
- ✅ Clean, simple code
- ✅ Guaranteed correctness
## API Reference

### ManagedServer Class

```python
class ManagedServer:
    def __init__(
        self,
        server: APIServer,
        tokenizer: Optional[Any] = None,
        track_tree: bool = False,
    ):
        """
        Initialize the managed server.

        Args:
            server: The underlying APIServer instance to wrap
            tokenizer: Tokenizer for encoding/decoding. If not provided,
                will attempt to extract from server or create from model name.
            track_tree: If True, maintains a tree structure with parent-child links.
                If False (default), maintains a simple list that updates in-place.
        """
```

### Methods

#### `async def chat_completion(**kwargs) -> ChatCompletion`
Intercepts a chat completion call and tracks sequences.

**Args:**
- `messages`: List of message dicts with 'role' and 'content'
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard chat completion parameters

**Returns:**
- `ChatCompletion` response (same as OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage
- Updates the `current_nodes` list (default mode) or `sequences` dict (tree mode)

#### `async def completion(**kwargs) -> Completion`
Intercepts a completion call and tracks sequences.

**Args:**
- `prompt`: The prompt string
- `n`: Number of completions to generate
- `max_tokens`: Maximum tokens in completion
- Other standard completion parameters

**Returns:**
- `Completion` response (same as OpenAI API)

**Side Effects:**
- Tracks sequences in internal storage

#### `def get_state() -> Dict[str, Any]`
Gets the current state of tracked sequences.

**Returns:**
- For default mode (track_tree=False):
  ```python
  {
      "nodes": List[SequenceNode]  # List of tracked sequences
  }
  ```
- For tree mode (track_tree=True):
  ```python
  {
      "sequences": Dict[str, SequenceNode],  # Keyed by full_text
      "tree": Dict[str, SequenceNode],       # Alias for compatibility
  }
  ```

#### `def reset()`
Clears all tracked sequences.

### Context Manager (Recommended Usage)

```python
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
    # Use managed.completion() or managed.chat_completion()
    ...
    # Get state before the context exits
    state = managed.get_state()
```

The context manager:
- Creates a `ManagedServer` instance
- Returns it for use within the block
- Automatically cleans up when the block exits

## Best Practices

1. **Always use the context manager pattern** for automatic cleanup:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       ...
   ```

2. **Get state before exiting the context**:
   ```python
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       completion = await managed.completion(...)
       state = managed.get_state()  # ✅ Do this inside the context
   # ❌ Don't try to access managed state here - the context has exited
   ```

3. **Use separate contexts for independent completions**:
   ```python
   # Context 1: Generate candidates
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       candidates = await managed.completion(...)
       state1 = managed.get_state()

   # Context 2: Judge candidates (independent)
   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
       judgments = await managed.completion(...)
       state2 = managed.get_state()
   ```

## Troubleshooting

### Issue: "Extension detection not working"

**Cause:** The new prompt doesn't exactly start with the previous node's `full_text`.

**Solution:** Ensure prompt strings match exactly, including whitespace:
```python
# Turn 1 produces: "Hello World"
# Turn 2 prompt must be: "Hello World..." (exact prefix match)
```
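Because detection is an exact string-prefix check, even a single extra space defeats it; for instance:

```python
prev = "Hello World"
assert "Hello World! How are you?".startswith(prev)       # extension detected
assert not "Hello  World! How are you?".startswith(prev)  # doubled space: no match, new sequence
```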
## Additional Resources

- [ManagedServer Source Code](managed_server.py)
- [ManagedServer Tests](../../tests/test_managed_server.py)
- [Example: math_server_zero.py](../../../../environments/math_server_zero.py#L320-L332)
- [Example: math_server.py](../../../../environments/math_server.py#L377-L387)
- [BaseEnv Documentation](../README.md)
- [API Server Documentation](../../api/README.md)
511  atroposlib/envs/server_handling/managed_server.py  Normal file

@@ -0,0 +1,511 @@
"""
|
||||
Managed server wrapper that tracks text sequences with aligned tokens and logprobs.
|
||||
|
||||
This wrapper maintains a tree structure of sequences, where:
|
||||
- Each node represents a complete text sequence (prompt + completion)
|
||||
- Tokens and logprobs are tracked with proper masking for training
|
||||
- Branching occurs organically from different contexts and n > 1 completions
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion import Completion, CompletionChoice
|
||||
from pydantic import BaseModel
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer
|
||||
|
||||
|
||||
class SequenceNode(BaseModel):
|
||||
"""
|
||||
A node in the sequence tree representing a complete text sequence.
|
||||
|
||||
Attributes:
|
||||
full_text: Complete text (prompt + completion)
|
||||
tokens: Full token sequence (actual token IDs)
|
||||
masked_tokens: Tokens with -100 for prompt positions, actual IDs for completion
|
||||
logprobs: Logprobs with 1.0 for prompt positions, actual values for completion
|
||||
metadata: Optional metadata (e.g., role information, finish_reason, etc.)
|
||||
"""
|
||||
|
||||
full_text: str
|
||||
tokens: List[int]
|
||||
masked_tokens: List[int]
|
||||
logprobs: List[float]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ManagedServer:
|
||||
"""
|
||||
Wrapper around APIServer that tracks sequences with aligned tokens and logprobs.
|
||||
|
||||
Maintains a tree structure keyed by input text, where each completion creates
|
||||
new branches. Provides proper masking for training (prompt tokens masked with -100,
|
||||
logprobs set to 1.0).
|
||||
|
||||
Uses the clean tokens_and_logprobs_completion interface internally.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: APIServer,
|
||||
tokenizer: Optional[Any] = None,
|
||||
track_tree: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the managed server.
|
||||
|
||||
Args:
|
||||
server: The underlying APIServer instance to wrap
|
||||
tokenizer: Optional tokenizer for encoding/decoding. If not provided,
|
||||
will attempt to extract from server or create from model name.
|
||||
track_tree: If True, maintains a tree structure with parent-child links
|
||||
(for multi-turn RL with per-step advantages). If False (default),
|
||||
maintains a simple list of current nodes that updates in-place.
|
||||
"""
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
|
||||
# Initialize storage based on mode
|
||||
if track_tree:
|
||||
self.sequences: Dict[str, SequenceNode] = {} # Tree mode: dict lookup
|
||||
else:
|
||||
self.current_nodes: List[SequenceNode] = [] # Default mode: simple list
|
||||
|
||||
# Try to get tokenizer from server if not provided
|
||||
if self.tokenizer is None:
|
||||
self._initialize_tokenizer()
|
||||
|
||||
def _initialize_tokenizer(self):
|
||||
"""Initialize tokenizer from server or model name."""
|
||||
# Check if the wrapped server has a tokenizer
|
||||
if hasattr(self.server, "tokenizer"):
|
||||
self.tokenizer = self.server.tokenizer
|
||||
else:
|
||||
# Try to create from model name
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name = self.server.config.model_name
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"Could not initialize tokenizer: {e}. "
|
||||
"Sequence tracking will be limited without tokenizer."
|
||||
)
|
||||
self.tokenizer = None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""
|
||||
Convert chat messages to prompt text using tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
if self.tokenizer is None:
|
||||
# Fallback: simple concatenation
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
# Only add generation prompt if last message is not from assistant
|
||||
add_generation_prompt = (
|
||||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
||||
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||
"""
|
||||
Find a node that this input extends (default mode).
|
||||
|
||||
Args:
|
||||
input_text: The input text to check
|
||||
|
||||
Returns:
|
||||
The node that input_text extends, or None if no match
|
||||
"""
|
||||
if not self.current_nodes:
|
||||
return None
|
||||
|
||||
# Check if any current node's full_text is a prefix of the input
|
||||
# This means the input is extending that node
|
||||
for node in self.current_nodes:
|
||||
if input_text.startswith(node.full_text):
|
||||
return node
|
||||
return None
|
||||
|
||||
def _compute_input_ids(
|
||||
self, input_text: str, extending_node: Optional[SequenceNode]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Compute input_ids for the prompt, using existing tokens if extending.
|
||||
|
||||
Args:
|
||||
input_text: The full input prompt text
|
||||
extending_node: Node being extended, if any
|
||||
|
||||
Returns:
|
||||
List of token IDs to use as input_ids
|
||||
"""
|
||||
if extending_node is not None:
|
||||
# Extending an existing sequence - use its tokens + tokenize the new part
|
||||
existing_text = extending_node.full_text
|
||||
new_text_suffix = input_text[len(existing_text) :]
|
||||
|
||||
# Tokenize only the new suffix (without BOS since we're continuing)
|
||||
if new_text_suffix:
|
||||
new_tokens = self.tokenizer.encode(
|
||||
new_text_suffix, add_special_tokens=False
|
||||
)
|
||||
return extending_node.tokens + new_tokens
|
||||
else:
|
||||
# No new text, just use existing tokens
|
||||
return extending_node.tokens.copy()
|
||||
else:
|
||||
# New sequence - tokenize the whole thing
|
||||
return self.tokenizer.encode(input_text, add_special_tokens=True)
|
||||
|
||||
def _find_parent_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||
"""
|
||||
Find a parent node whose full_text matches the input_text (tree mode).
|
||||
|
||||
Args:
|
||||
input_text: The input text to search for
|
||||
|
||||
Returns:
|
||||
Parent SequenceNode if found, None otherwise
|
||||
"""
|
||||
return self.sequences.get(input_text, None)
|
||||
|
||||
def _create_sequence_node(
|
||||
self,
|
||||
input_text: str,
|
||||
parent_node: Optional[SequenceNode],
|
||||
prompt_tokens: List[int],
|
||||
output_tokens: List[int],
|
||||
output_logprobs: List[float],
|
||||
completion_text: str,
|
||||
finish_reason: str = "stop",
|
||||
) -> SequenceNode:
|
||||
"""
|
||||
Create a sequence node with proper masking.
|
||||
|
||||
Args:
|
||||
input_text: The input prompt text
|
||||
parent_node: Parent node to extend from (if available)
|
||||
prompt_tokens: Token IDs for the prompt
|
||||
output_tokens: Token IDs for the output/completion
|
||||
output_logprobs: Logprobs for output tokens
|
||||
completion_text: The completion text
|
||||
finish_reason: Finish reason from server
|
||||
|
||||
Returns:
|
||||
SequenceNode with properly masked tokens and logprobs
|
||||
"""
|
||||
# Combine text
|
||||
full_text = input_text + completion_text
|
||||
|
||||
# If we have a parent node, we should use its tokens as the prompt base
|
||||
if parent_node is not None:
|
||||
# Use parent's full tokens as the prompt
|
||||
prompt_tokens = parent_node.tokens.copy()
|
||||
|
||||
# Combine tokens
|
||||
full_tokens = prompt_tokens + output_tokens
|
||||
prompt_len = len(prompt_tokens)
|
||||
|
||||
# Create masked tokens: -100 for prompt, actual IDs for completion
|
||||
masked_tokens = [-100] * prompt_len + output_tokens
|
||||
|
||||
# Create masked logprobs: 1.0 for prompt, actual for completion
|
||||
# Pad logprobs to match token length if needed
|
||||
if len(output_logprobs) < len(output_tokens):
|
||||
output_logprobs = output_logprobs + [1.0] * (
|
||||
len(output_tokens) - len(output_logprobs)
|
||||
)
|
||||
elif len(output_logprobs) > len(output_tokens):
|
||||
output_logprobs = output_logprobs[: len(output_tokens)]
|
||||
|
||||
full_logprobs = [1.0] * prompt_len + output_logprobs
|
||||
|
||||
return SequenceNode(
|
||||
full_text=full_text,
|
||||
tokens=full_tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=full_logprobs,
|
||||
metadata={"finish_reason": finish_reason},
|
||||
)
|
||||
|
||||
async def chat_completion(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Intercept chat completion call and track sequences.
|
||||
|
||||
Internally converts to prompt, calls tokens_and_logprobs_completion,
|
||||
tracks the sequence, and reconstructs a ChatCompletion response.
|
||||
|
||||
Args:
|
||||
**kwargs: Standard chat completion kwargs (messages, n, etc.)
|
||||
|
||||
Returns:
|
||||
ChatCompletion response
|
||||
"""
|
||||
# Get input text
|
||||
messages = kwargs.get("messages", [])
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
|
||||
# Handle parent node and extending logic based on mode
|
||||
if self.track_tree:
|
||||
# Tree mode: look up parent in dict
|
||||
parent_node = self._find_parent_node(prompt)
|
||||
extending_node = None
|
||||
else:
|
||||
# Default mode: check if extending existing sequence
|
||||
extending_node = self._find_extending_node(prompt)
|
||||
parent_node = None # Don't use parent merging in default mode
|
||||
|
||||
# Convert to completion format
|
||||
completion_kwargs = kwargs.copy()
|
||||
completion_kwargs["prompt"] = prompt
|
||||
completion_kwargs.pop("messages", None)
|
||||
|
||||
# Set model name if not provided
|
||||
if "model" not in completion_kwargs:
|
||||
completion_kwargs["model"] = self.server.config.model_name
|
||||
|
||||
# Compute input_ids (using existing tokens if extending)
|
||||
if not self.track_tree and self.tokenizer is not None:
|
||||
input_ids = self._compute_input_ids(prompt, extending_node)
|
||||
completion_kwargs["input_ids"] = input_ids
|
||||
|
||||
# Call the tokens and logprobs wrapper directly
|
||||
(
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons,
|
||||
) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
|
||||
|
||||
# Track each completion and build choices
|
||||
n = len(output_tokens_list)
|
||||
choices = []
|
||||
|
||||
for i in range(n):
|
||||
output_tokens = output_tokens_list[i]
|
||||
output_logprobs = output_logprobs_list[i]
|
||||
finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"
|
||||
|
||||
# Extract finish_reason string from dict if needed
|
||||
if isinstance(finish_reason_raw, dict):
|
||||
finish_reason = finish_reason_raw.get("type", "stop")
|
||||
else:
|
||||
finish_reason = finish_reason_raw
|
||||
|
||||
# Decode completion text
|
||||
if self.tokenizer is not None:
|
||||
completion_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
)
|
||||
else:
|
||||
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
|
||||
|
||||
# Create and store sequence node
|
||||
node = self._create_sequence_node(
|
||||
input_text=prompt,
|
||||
parent_node=parent_node,
|
||||
prompt_tokens=prompt_tokens,
|
||||
output_tokens=output_tokens,
|
||||
output_logprobs=output_logprobs,
|
||||
completion_text=completion_text,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
# Store node based on mode
|
||||
if self.track_tree:
|
||||
# Tree mode: key by full text in dict
|
||||
self.sequences[node.full_text] = node
|
||||
else:
|
||||
# Default mode: replace if extending, append if new context
|
||||
if extending_node is not None:
|
||||
# Replace the extending node with the new extended version
|
||||
try:
|
||||
idx = self.current_nodes.index(extending_node)
|
||||
self.current_nodes[idx] = node
|
||||
except ValueError:
|
||||
# Extending node not in list anymore, just append
|
||||
self.current_nodes.append(node)
|
||||
else:
|
||||
# New context - append to list
|
||||
self.current_nodes.append(node)
|
||||
|
||||
# Build choice
|
||||
choice = Choice(
|
||||
finish_reason=finish_reason,
|
||||
index=i,
|
||||
message=ChatCompletionMessage(
|
||||
content=completion_text, role="assistant"
|
||||
),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
# Construct ChatCompletion response
|
||||
return ChatCompletion(
|
||||
id=str(uuid.uuid4()),
|
||||
created=int(time.time()),
|
||||
model=self.server.config.model_name,
|
||||
object="chat.completion",
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
async def completion(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Intercept completion call and track sequences.
|
||||
|
||||
Uses tokens_and_logprobs_completion internally, tracks the sequence,
|
||||
and reconstructs a Completion response.
|
||||
|
||||
Args:
|
||||
**kwargs: Standard completion kwargs (prompt, n, etc.)
|
||||
|
||||
Returns:
|
||||
Completion response
|
||||
"""
|
||||
# Get input text
|
||||
prompt = kwargs.get("prompt", "")
|
||||
|
||||
# Handle parent node and extending logic based on mode
|
||||
if self.track_tree:
|
||||
# Tree mode: look up parent in dict
|
||||
parent_node = self._find_parent_node(prompt)
|
||||
extending_node = None
|
||||
else:
|
||||
# Default mode: check if extending existing sequence
|
||||
extending_node = self._find_extending_node(prompt)
|
||||
parent_node = None # Don't use parent merging in default mode
|
||||
|
||||
# Set model name if not provided
|
||||
if "model" not in kwargs:
|
||||
kwargs["model"] = self.server.config.model_name
|
||||
|
||||
# Compute input_ids (using existing tokens if extending)
|
||||
if not self.track_tree and self.tokenizer is not None:
|
||||
input_ids = self._compute_input_ids(prompt, extending_node)
|
||||
kwargs["input_ids"] = input_ids
|
||||
|
||||
# Call the tokens and logprobs wrapper directly
|
||||
(
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons,
|
||||
) = await self.server.tokens_and_logprobs_completion(**kwargs)
|
||||
|
||||
# Track each completion and build choices
|
||||
n = len(output_tokens_list)
|
||||
choices = []
|
||||
|
||||
for i in range(n):
|
||||
output_tokens = output_tokens_list[i]
|
||||
output_logprobs = output_logprobs_list[i]
|
||||
finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"
|
||||
|
||||
# Extract finish_reason string from dict if needed
|
||||
if isinstance(finish_reason_raw, dict):
|
||||
finish_reason = finish_reason_raw.get("type", "stop")
|
||||
else:
|
||||
finish_reason = finish_reason_raw
|
||||
|
||||
# Decode completion text
|
||||
if self.tokenizer is not None:
|
||||
completion_text = self.tokenizer.decode(
|
||||
output_tokens, skip_special_tokens=True
|
||||
)
|
||||
else:
|
||||
completion_text = "".join([chr(t) for t in output_tokens if t > 31])
|
||||
|
||||
# Create and store sequence node
|
||||
node = self._create_sequence_node(
|
||||
input_text=prompt,
|
||||
parent_node=parent_node,
|
||||
prompt_tokens=prompt_tokens,
|
||||
output_tokens=output_tokens,
|
||||
output_logprobs=output_logprobs,
|
||||
completion_text=completion_text,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
# Store node based on mode
|
||||
if self.track_tree:
|
||||
# Tree mode: key by full text in dict
|
||||
self.sequences[node.full_text] = node
|
||||
else:
|
||||
# Default mode: replace if extending, append if new context
|
||||
if extending_node is not None:
|
||||
# Replace the extending node with the new extended version
|
||||
try:
|
||||
idx = self.current_nodes.index(extending_node)
|
||||
self.current_nodes[idx] = node
|
||||
except ValueError:
|
||||
# Extending node not in list anymore, just append
|
||||
self.current_nodes.append(node)
|
||||
else:
|
||||
# New context - append to list
|
||||
self.current_nodes.append(node)
|
||||
|
||||
# Build choice
|
||||
choice = CompletionChoice(
|
||||
finish_reason=finish_reason, index=i, text=completion_text
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
# Construct Completion response
|
||||
return Completion(
|
||||
id=str(uuid.uuid4()),
|
||||
created=int(time.time()),
|
||||
model=self.server.config.model_name,
|
||||
object="text_completion",
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current state of tracked sequences.
|
||||
|
||||
Returns:
|
||||
For default mode (track_tree=False):
|
||||
Dictionary with 'nodes': List[SequenceNode] - ready for training
|
||||
For tree mode (track_tree=True):
|
||||
Dictionary with 'sequences': Dict[str, SequenceNode] and 'tree' alias
|
||||
"""
|
||||
if self.track_tree:
|
||||
return {
|
||||
"sequences": self.sequences.copy(),
|
||||
"tree": self.sequences.copy(), # Alias for compatibility
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"nodes": self.current_nodes.copy(), # Return a copy so reset() doesn't affect it
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Clear all tracked sequences."""
|
||||
if self.track_tree:
|
||||
self.sequences.clear()
|
||||
else:
|
||||
self.current_nodes.clear()
|
||||
|
|
atroposlib/envs/server_handling/openai_server.py

@@ -134,6 +134,16 @@ class OpenAIServer(APIServer):
        completions.choices.extend(c.choices)
        return completions

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for the tokens and logprobs completion using the openai client.
        """
        raise NotImplementedError(
            "Tokens and logprobs not supported by base OpenAI API, use specific API servers."
        )


def resolve_openai_configs(
    default_server_configs,
atroposlib/envs/server_handling/server_baseline.py

@@ -108,7 +108,7 @@ class ServerBaseline(BaseModel):
    rolling_buffer_length: int = Field(
        default=1000, description="Length of the rolling buffer to store metrics."
    )
-   server_type: Literal["openai", "trl"] = Field(
+   server_type: Literal["openai", "trl", "sglang"] = Field(
        default="openai", description="Type of server to use, openai or trl"
    )
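For illustration, selecting the new backend is just a config value (a minimal sketch; it assumes your other `ServerBaseline`/server config fields are set as usual):

```python
# Hypothetical config sketch: route requests through the new SGLang server type.
# Other fields (model name, URLs, etc.) are whatever your setup already uses.
config = ServerBaseline(server_type="sglang")
```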
@@ -217,6 +217,16 @@ class APIServer(ABC):
        """
        pass

    @abstractmethod
    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for tokens and logprobs completion. Should be overridden by the child class.
        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
        """
        pass

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
@@ -352,3 +362,77 @@ class APIServer(ABC):
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _tokens_and_logprobs_comp(
        self, stat_dict, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Simple retry and stat collection wrapper for tokens and logprobs completion.
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    @retry(
        stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
    )
    async def _tokens_and_logprobs_comp_eval(
        self, stat_dict, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Simple retry and stat collection wrapper for tokens and logprobs completion (eval split).
        """
        while not self.server_healthy:
            await asyncio.sleep(1)
        async with self.eval_sem:
            if stat_dict.get("start", None) is None:
                stat_dict["start"] = time.time()
            stat_dict["attempts"] += 1
            completions = await self._tokens_and_logprobs_completion_wrapper(**kwargs)
            stat_dict["end"] = time.time()
            return completions

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Tokens and logprobs completion handler; waits for the server to be healthy and then calls the wrapper.
        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
        """
        if not self.initialized:
            if self.config.health_check:
                if (
                    self.config.base_url is not None
                ):  # skip health check if using OpenAI API
                    self.check_task = asyncio.create_task(
                        self.check_server_status_task(chat_completion=False)
                    )
                else:
                    self.server_healthy = True
            else:
                # If health_check is False, always assume healthy
                self.server_healthy = True
            self.initialized = True
        kwargs["model"] = self.config.model_name
        split = kwargs.pop("split", "train")
        stat_dict = {}
        stat_dict["attempts"] = 0
        if split == "train":
            ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs)
            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.attempts_list.append(stat_dict["attempts"])
        else:
            # Give separate eval workers, if desired, gotta go fast for those evals
            ret_data = await self._tokens_and_logprobs_comp_eval(stat_dict, **kwargs)
            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
            self.eval_attempts_list.append(stat_dict["attempts"])
        return ret_data
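As a usage illustration (a sketch; normally `ManagedServer` calls this for you), the handler returns the four parallel pieces directly:

```python
# Hypothetical direct call against a configured APIServer instance `server`.
prompt_toks, out_toks_list, out_lps_list, finish_reasons = (
    await server.tokens_and_logprobs_completion(prompt="What is 2+2?", n=4, max_tokens=64)
)
assert len(out_toks_list) == len(out_lps_list) == len(finish_reasons) == 4
```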
atroposlib/envs/server_handling/server_harness.py
@@ -87,6 +87,7 @@ def create_completion(
class ServerHarness:
    def __init__(self):
        self.response_map = dict()
        self.tokens_and_logprobs_map = dict()  # Map for tokens/logprobs responses
        self.sem = asyncio.Semaphore(1)
        self.eval_sem = asyncio.Semaphore(1)
        pass
@@ -110,6 +111,31 @@ class ServerHarness:
    def set_desired_completion(self, input_message: str, completion: Completion):
        self.response_map[input_message] = completion

    def set_tokens_and_logprobs_response(
        self,
        prompt: str,
        prompt_tokens: list,
        output_tokens_list: list,
        output_logprobs_list: list,
        finish_reasons: list,
    ):
        """
        Set expected response for _tokens_and_logprobs_completion_wrapper.

        Args:
            prompt: The prompt string (key)
            prompt_tokens: List of prompt token IDs
            output_tokens_list: List of lists of output token IDs (one per completion)
            output_logprobs_list: List of lists of output logprobs (one per completion)
            finish_reasons: List of finish reasons (one per completion)
        """
        self.tokens_and_logprobs_map[prompt] = (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons,
        )
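In a test this pairs naturally with the mock completion call further down; a quick sketch:

```python
# Hypothetical test setup: preload the harness with a canned tokens/logprobs
# response, then exercise code that calls tokens_and_logprobs_completion("Hello").
harness = ServerHarness()
harness.set_tokens_and_logprobs_response(
    prompt="Hello",
    prompt_tokens=[1, 2],
    output_tokens_list=[[3, 4]],
    output_logprobs_list=[[-0.1, -0.2]],
    finish_reasons=["stop"],
)
```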

    async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
        messages = kwargs.get("messages")
        dictkey = self.conv_to_dictkey(messages)
@@ -125,6 +151,21 @@ class ServerHarness:
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for key:\n{prompt}")

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Mock implementation of the tokens and logprobs completion wrapper.

        Returns:
            Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
        """
        prompt = kwargs.get("prompt")
        try:
            return self.tokens_and_logprobs_map[prompt]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for prompt:\n{prompt}")


if __name__ == "__main__":

atroposlib/envs/server_handling/server_manager.py
@@ -8,6 +8,7 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field

from atroposlib.envs.server_handling.managed_server import ManagedServer
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import (
    APIServer,
@@ -15,6 +16,7 @@ from atroposlib.envs.server_handling.server_baseline import (
    ServerBaseline,
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
@@ -54,6 +56,8 @@ class ServerManager:
            server_class = OpenAIServer
        elif configs.server_type == "trl":
            server_class = TrlVllmServer
        elif configs.server_type == "sglang":
            server_class = SGLangServer
        else:
            raise ValueError(f"Invalid server type: {configs.server_type}")
    else:
@@ -61,6 +65,8 @@ class ServerManager:
            server_class = OpenAIServer
        elif configs[0].server_type == "trl":
            server_class = TrlVllmServer
        elif configs[0].server_type == "sglang":
            server_class = SGLangServer
        else:
            raise ValueError(f"Invalid server type: {configs[0].server_type}")
    if testing:
@@ -241,6 +247,53 @@ class ServerManager:
            )
        return await self.servers[most_available_server].completion(**kwargs)

    async def tokens_and_logprobs_completion(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Get tokens and logprobs from a completion.
        Returns (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
        """
        n = kwargs.get("n", 1)
        if n > self.max_n_completions:
            # Split into multiple completions
            results = []
            total_n = n
            while total_n > 0:
                n_to_use = min(total_n, self.max_n_completions)
                kwargs["n"] = n_to_use
                results.append(self.tokens_and_logprobs_completion(**kwargs))
                total_n -= n_to_use
            results = await asyncio.gather(*results)
            # Merge results - prompt_tokens should be the same, extend the output lists
            prompt_tokens = results[0][0]
            output_tokens = []
            output_logprobs = []
            finish_reasons = []
            for _, out_tokens, out_logprobs, out_finish_reasons in results:
                output_tokens.extend(out_tokens)
                output_logprobs.extend(out_logprobs)
                finish_reasons.extend(out_finish_reasons)
            return (prompt_tokens, output_tokens, output_logprobs, finish_reasons)

        is_train = kwargs.get("split", "train") == "train"
        most_available_server = 0
        most_available_server_num_slots = -1
        await self.wait_for_sem(is_train)
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if (
                server.sem._value if is_train else server.eval_sem._value
            ) > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = (
                    server.sem._value if is_train else server.eval_sem._value
                )
        return await self.servers[most_available_server].tokens_and_logprobs_completion(
            **kwargs
        )
|
||||
|
||||
@asynccontextmanager
|
||||
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
|
||||
most_available_server = 0
|
||||
|
@@ -256,3 +309,50 @@ class ServerManager:
            yield self.servers[most_available_server]
        finally:
            pass

    @asynccontextmanager
    async def managed_server(
        self, tokenizer=None
    ) -> AsyncGenerator[ManagedServer, None]:
        """
        Context manager that provides a ManagedServer instance.

        The ManagedServer wraps the most available server and tracks text sequences
        with aligned tokens and logprobs. State is automatically cleared on exit.

        Args:
            tokenizer: Optional tokenizer to use. If not provided, will attempt to
                extract it from the server or create it from the model name.

        Yields:
            ManagedServer instance wrapping the selected server

        Example:
            async with server_manager.managed_server() as managed:
                response = await managed.chat_completion(
                    messages=[{"role": "user", "content": "Hello"}],
                    n=2,
                )
                state = managed.get_state()
                # Process state...
            # State is automatically cleared when exiting the context
        """
        most_available_server = 0
        most_available_server_num_slots = -1
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            if server.sem._value > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = server.sem._value

        # Create a ManagedServer wrapping the selected server
        managed = ManagedServer(
            server=self.servers[most_available_server], tokenizer=tokenizer
        )

        try:
            yield managed
        finally:
            # Clean up: reset tracked sequences
            managed.reset()
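For completions-style environments that want raw token/logprob alignment without the `ManagedServer` wrapper, the manager-level method can be called directly. A sketch under assumptions: `self.server` is the environment's configured `ServerManager`, `format_prompt` is a hypothetical helper, the model name is passed explicitly from the environment config, and `max_n_completions` is 8, so `n=16` fans out into two sub-requests whose output lists are concatenated in order.

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)  # hypothetical helper
    prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
        await self.server.tokens_and_logprobs_completion(
            model=self.config.model_name,  # assumption: model name passed explicitly
            prompt=prompt,
            n=16,
            max_tokens=64,
            temperature=1.0,
        )
    )
    # One entry per completion, aligned index-for-index.
    assert len(output_tokens) == len(output_logprobs) == len(finish_reasons) == 16
```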
atroposlib/envs/server_handling/sglang_server.py (new file, 355 lines)

@@ -0,0 +1,355 @@
import asyncio
import warnings

import aiohttp
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic_cli import FailedExecutionException
from transformers import AutoTokenizer

from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig


class SGLangServer(APIServer):
    """
    SGLang server handling.
    """

    def __init__(self, config: APIServerConfig):
        self.openai = openai.AsyncClient(
            api_key=config.api_key,
            base_url=config.base_url,
            timeout=config.timeout,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        super().__init__(config)
    async def check_server_status_task(self, chat_completion: bool = True):
        while True:
            try:
                if chat_completion:
                    await self.openai.chat.completions.create(
                        model=self.config.model_name,
                        messages=[{"role": "user", "content": "hi"}],
                        max_tokens=1,
                    )
                else:
                    await self.openai.completions.create(
                        model=self.config.model_name,
                        prompt="hi",
                        max_tokens=1,
                    )
                self.server_healthy = True
            except (
                aiohttp.ClientError,
                openai.OpenAIError,
                openai.APITimeoutError,
                Exception,
            ):
                self.server_healthy = False
            await asyncio.sleep(1)
    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for the chat completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for chat completion!"
        assert (
            kwargs.get("messages", None) is not None
        ), "Messages are required for chat completion!"
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.chat.completions.create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            if n > 1:
                for c in completion_list[1:]:
                    completions.choices.extend(c.choices)
            else:
                completions = await self.openai.chat.completions.create(**kwargs)
        else:
            if "n" in kwargs:
                n = kwargs["n"]
            else:
                n = 1
            completions = await self.openai.chat.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                else:
                    warnings.warn("n kwarg is ignored by the API, setting to True")
                    self.config.n_kwarg_is_ignored = True
                    completion_list = await asyncio.gather(
                        *[
                            self.openai.chat.completions.create(**kwargs)
                            for _ in range(1, n)
                        ]
                    )
                    for c in completion_list:
                        completions.choices.extend(c.choices)
        return completions
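The fallback above handles backends that silently ignore `n`: if a request for `n` completions comes back with a single choice, the wrapper flips `config.n_kwarg_is_ignored`, issues the remaining `n-1` requests concurrently, and merges the choices. A hypothetical sketch of the resulting behavior (server construction not shown):

```python
async def demo(server):  # `server` is an SGLangServer; model name is a placeholder
    resp = await server._chat_completion_wrapper(
        model="Qwen/Qwen3-4B-Instruct-2507",
        messages=[{"role": "user", "content": "Hello"}],
        n=4,
    )
    # Whether the backend honored n or the wrapper fanned out four requests,
    # the merged response carries all four choices.
    assert len(resp.choices) == 4
```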
    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the openai client.
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
        ), "Prompt is required for completion!"
        if self.config.n_kwarg_is_ignored:
            n = kwargs.pop("n", 1)
            completion_list = await asyncio.gather(
                *[self.openai.completions.create(**kwargs) for _ in range(n)]
            )
            completions = completion_list[0]
            if n > 1:
                for c in completion_list[1:]:
                    completions.choices.extend(c.choices)
        else:
            if "n" in kwargs:
                n = kwargs["n"]
            else:
                n = 1
            completions = await self.openai.completions.create(**kwargs)
            if len(completions.choices) != n:
                if len(completions.choices) != 1:
                    raise ValueError(
                        f"Expected 1 or {n} completions, got {len(completions.choices)}!"
                    )
                else:
                    warnings.warn("n kwarg is ignored by the API, setting to True")
                    self.config.n_kwarg_is_ignored = True
                    completion_list = await asyncio.gather(
                        *[self.openai.completions.create(**kwargs) for _ in range(1, n)]
                    )
                    for c in completion_list:
                        completions.choices.extend(c.choices)
        return completions
    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for tokens and logprobs completion using SGLang's native API.
        Returns a tuple of (prompt_tokens, output_tokens, output_logprobs, finish_reasons).
        Each output element is a list of lists (one per completion in the batch).
        """
        assert (
            kwargs.get("model", None) is not None
        ), "Model is required for completion!"
        assert (
            kwargs.get("prompt", None) is not None
            or kwargs.get("input_ids", None) is not None
        ), "Prompt or input_ids is required for completion!"

        # Use input_ids if provided (from ManagedServer), otherwise tokenize the prompt
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)  # Remove prompt if it exists
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))

        # Check for a double BOS token, which can happen if you use chat templates
        # and forget that they insert a BOS token
        if (
            len(prompt_tokens) >= 2
            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
        ):
            prompt_tokens = prompt_tokens[1:]
        if "max_tokens" in kwargs:
            kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
        if "model" in kwargs:
            kwargs.pop("model")
        # Prepare the request for SGLang's native API
        request_data = {
            "input_ids": prompt_tokens,
            "sampling_params": kwargs,
            "return_logprob": True,
            "return_text_in_logprobs": False,  # We want raw token IDs, not text
        }

        # Make an async request to SGLang's /generate endpoint
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.config.base_url.replace('/v1', '')}/generate",
                json=request_data,
                headers=(
                    {"Authorization": f"Bearer {self.config.api_key}"}
                    if self.config.api_key
                    else {}
                ),
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
                response.raise_for_status()
                results = await response.json()

        # Handle both single and batch responses
        if not isinstance(results, list):
            results = [results]

        output_tokens_list = []
        output_logprobs_list = []
        finish_reasons_list = []

        for result in results:
            meta_info = result.get("meta_info", {})

            # Each output_token_logprobs entry is a (logprob, token_id, text) tuple;
            # extract the logprob values and raw token IDs separately.
            output_token_logprobs = meta_info.get("output_token_logprobs", [])
            logprobs = [item[0] for item in output_token_logprobs]
            output_ids = [item[1] for item in output_token_logprobs]

            # Get finish reason
            finish_reason = meta_info.get("finish_reason", None)

            output_tokens_list.append(output_ids)
            output_logprobs_list.append(logprobs)
            finish_reasons_list.append(finish_reason)

        return (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons_list,
        )
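For reference, an illustrative shape of the SGLang `/generate` JSON this wrapper parses; the values below are invented, not captured from a server, and the `finish_reason` payload shape is an assumption. `output_token_logprobs` holds `(logprob, token_id, text)` tuples, with `text` left empty because `return_text_in_logprobs` is `False`:

```python
# Illustrative response body only.
example_result = {
    "text": " Paris",
    "meta_info": {
        "finish_reason": {"type": "stop"},  # assumption about the payload shape
        "output_token_logprobs": [
            [-0.11, 12095, None],  # (logprob, token_id, text)
            [-1.73, 13, None],
        ],
    },
}

tuples = example_result["meta_info"]["output_token_logprobs"]
logprobs = [item[0] for item in tuples]   # [-0.11, -1.73]
token_ids = [item[1] for item in tuples]  # [12095, 13]
```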
def resolve_openai_configs(
    default_server_configs,
    openai_config_dict,
    yaml_config,
    cli_passed_flags,
    logger,
):
    """
    Helper to resolve the final server_configs, handling single servers, multiple servers, and overrides.
    """
    from atroposlib.envs.server_handling.server_manager import ServerBaseline

    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
    openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None)
    openai_cli_config = {
        k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix)
    }

    is_multi_server_yaml = (
        isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2
    )
    is_multi_server_default = (
        (not is_multi_server_yaml)
        and isinstance(default_server_configs, list)
        and len(default_server_configs) >= 2
    )

    if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
        raise FailedExecutionException(
            message=f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported "
            f"when multiple servers are defined (either via a YAML list under '{OPENAI_NAMESPACE}' "
            "or a default list with length >= 2).",
            exit_code=2,
        )

    if is_multi_server_yaml:
        logger.info(
            f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'."
        )
        try:
            server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config]
        except Exception as e:
            raise FailedExecutionException(
                f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
            ) from e
    elif isinstance(default_server_configs, ServerBaseline):
        logger.info("Using ServerBaseline configuration.")
        server_configs = default_server_configs
    elif is_multi_server_default:
        logger.info("Using default multi-server configuration (length >= 2).")
        server_configs = default_server_configs
    else:
        logger.info(
            "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
        )
        try:
            final_openai_config = APIServerConfig(**openai_config_dict)
        except Exception as e:
            raise FailedExecutionException(
                f"Error creating final OpenAI configuration from merged settings: {e}\n"
                f"Merged Dict: {openai_config_dict}"
            ) from e

        if isinstance(default_server_configs, APIServerConfig):
            server_configs = final_openai_config
        elif isinstance(default_server_configs, list):
            server_configs = [final_openai_config]
        else:
            logger.warning(
                f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
                "Proceeding with single OpenAI server configuration based on merged settings."
            )
            server_configs = [final_openai_config]

    return server_configs
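The multi-server branch above triggers when the YAML puts a list of two or more entries under the OpenAI namespace. A hypothetical parsed-YAML dict showing that shape; the namespace key and all field values are assumptions, with field names mirroring `APIServerConfig`:

```python
# Stand-in for yaml_config after parsing; everything below is illustrative.
yaml_config = {
    "openai": [  # assumption: OPENAI_NAMESPACE == "openai"
        {
            "base_url": "http://localhost:30000/v1",
            "model_name": "Qwen/Qwen3-4B-Instruct-2507",
            "server_type": "sglang",
        },
        {
            "base_url": "http://localhost:30001/v1",
            "model_name": "Qwen/Qwen3-4B-Instruct-2507",
            "server_type": "sglang",
        },
    ]
}
# With two entries, is_multi_server_yaml is True and CLI --openai.* overrides are rejected.
```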
if __name__ == "__main__":

    async def test_tokens_and_logprobs():
        # Configure the server - update these values for your setup
        config = APIServerConfig(
            api_key="",  # Add your API key if needed
            base_url="http://localhost:30000",  # Update to your SGLang server URL
            model_name="Qwen/Qwen3-4B-Instruct-2507",  # Update to your model name
            timeout=120,
        )

        server = SGLangServer(config)

        # Test the tokens_and_logprobs_completion method
        print("Testing tokens_and_logprobs_completion...")
        try:
            prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
                await server.tokens_and_logprobs_completion(
                    prompt="The capital of France is",
                    n=4,
                    max_tokens=32,
                    temperature=1.0,
                    top_p=1.0,
                    stop=["User:", "Human:", "Assistant:", "</answer>"],
                )
            )

            print("\nResults:")
            print(f"Prompt tokens: {prompt_tokens}")
            print(f"Output tokens: {output_tokens}")
            print(f"Output logprobs (first 5): {[lp[:5] for lp in output_logprobs]}")
            print(f"Finish reasons: {finish_reasons}")
            print(f"\nNumber of completions: {len(output_tokens)}")
            print(f"Output length: {[len(tokens) for tokens in output_tokens]}")
            responses = "\n\n".join(
                [server.tokenizer.decode(tokens) for tokens in output_tokens]
            )
            print(f"Responses:\n-{responses}")

        except Exception as e:
            print(f"Error: {e}")
            import traceback

            traceback.print_exc()

    # Run the test
    asyncio.run(test_tokens_and_logprobs())
atroposlib/envs/server_handling/trl_vllm_server.py

@@ -13,6 +13,7 @@ from openai.types.chat.chat_completion import (
    ChatCompletionMessage,
    Choice,
)
from openai.types.completion import Completion, CompletionChoice
from transformers import AutoTokenizer

from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
@@ -81,7 +82,7 @@ class TrlVllmServer(APIServer):
        )
        return completions

-    async def _completion_wrapper(self, **kwargs) -> ChatCompletion:
+    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for the completion using the trl's vLLM server.
        """
@@ -102,25 +103,30 @@ class TrlVllmServer(APIServer):
            },
        ) as response:
            completions = await response.json()
-        completions = ChatCompletion(
+        completions = Completion(
            id=str(uuid.uuid4()),
-            object="chat.completion",
+            object="text_completion",
            created=int(time.time()),
            model=self.config.model_name,
            choices=[
-                Choice(
+                CompletionChoice(
                    finish_reason=(
                        "stop"
                        if self.tokenizer.eos_token_id in completion
                        else "length"
                    ),
                    index=i,
-                    message=ChatCompletionMessage(
-                        content=self.tokenizer.decode(completion),
-                        role="assistant",
-                    ),
+                    text=self.tokenizer.decode(completion),
                )
                for i, completion in enumerate(completions["completion_ids"])
            ],
        )
        return completions

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[list, list, list, list]:
        """
        Wrapper for the tokens and logprobs completion using the openai client.
        """
        raise NotImplementedError("Not implemented for trl's vLLM server yet.")
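Since `TrlVllmServer._completion_wrapper` now returns an OpenAI `Completion` instead of a `ChatCompletion`, callers read `choice.text` rather than `choice.message.content`. An illustrative, hypothetical call site (server construction and argument values are placeholders):

```python
async def show_choices(server):  # `server` is a configured TrlVllmServer
    completion = await server.completion(
        model="my-model",  # hypothetical model name
        prompt="The capital of France is",
        n=2,
        max_tokens=16,
    )
    for choice in completion.choices:
        print(choice.text, choice.finish_reason)  # .text, not .message.content
```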