mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add docs :)
This commit is contained in:
parent
c3a118f50d
commit
5d6d6bb0dc
6 changed files with 892 additions and 21 deletions
@ -17,6 +17,98 @@ To achieve this generality, our environment abstraction deviates from other open
- **Environments return tokens (not messages!)**: One of our more unusual design choices is that, at least for text-only environments, the environment is responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments in the same training run. For cases such as multimodal training, where an OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55).

## Working with Servers and ManagedServer

**🎯 Recommended Approach:** Use `ManagedServer` for automatic token and logprob tracking!

When implementing `collect_trajectory` or `collect_trajectories`, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The **recommended way** to do this is using `ManagedServer`, which automatically handles tokenization, masking, and logprob alignment.

### ManagedServer Overview

`ManagedServer` wraps your `APIServer` and automatically tracks:

- **Tokens**: Full unmasked token sequences
- **Masked Tokens**: Training format with `-100` for prompt positions, actual token IDs for completion
- **Logprobs**: Training format with `1.0` for masked positions, actual logprob values for completion
- **Full Text**: Complete text (prompt + completion)
- **Metadata**: Finish reasons and other information

**Why 1.0 for masked logprobs?** A valid log-probability is never positive, so `1.0` is an "obviously bad" sentinel: it would correspond to a probability of e^1.0 ≈ 2.718 > 1.0, which is impossible. That makes masked positions easy to identify and ignore during training.

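To make the masking convention concrete, here is a minimal standalone sketch of how the three aligned sequences relate to each other. It does not use `ManagedServer` at all: `TrackedSequence` and `build_tracked_sequence` are hypothetical names invented for this illustration, and the token IDs are made up. Only the `-100` and `1.0` markers come from the description above.

```python
from dataclasses import dataclass

MASK_TOKEN_ID = -100  # marker for prompt positions in masked_tokens
MASK_LOGPROB = 1.0    # marker for prompt positions in logprobs


@dataclass
class TrackedSequence:
    """Hypothetical stand-in for one tracked sequence (not the library's class)."""

    tokens: list[int]
    masked_tokens: list[int]
    logprobs: list[float]


def build_tracked_sequence(prompt_ids, completion_ids, completion_logprobs):
    """Build aligned tokens / masked_tokens / logprobs for prompt + completion."""
    if len(completion_ids) != len(completion_logprobs):
        raise ValueError("need exactly one logprob per completion token")
    n_prompt = len(prompt_ids)
    return TrackedSequence(
        tokens=list(prompt_ids) + list(completion_ids),
        masked_tokens=[MASK_TOKEN_ID] * n_prompt + list(completion_ids),
        logprobs=[MASK_LOGPROB] * n_prompt + list(completion_logprobs),
    )


# Made-up IDs: a 3-token prompt followed by a 2-token completion.
seq = build_tracked_sequence([101, 7592, 2088], [2023, 2003], [-0.25, -1.5])

# Real logprobs are never positive, so the 1.0 marker is trivial to filter out.
trainable = [i for i, lp in enumerate(seq.logprobs) if lp <= 0.0]
```

All three lists have the same length, so position `i` refers to the same token in each of them, and `trainable` contains only the completion positions.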

### Basic Usage Pattern

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server with context manager
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

    # Extract pre-computed, guaranteed-aligned data
    for choice, node in zip(completion.choices, nodes):
        tokens = node.tokens  # ✅ Automatically computed
        masked_tokens = node.masked_tokens  # ✅ Automatically masked
        logprobs = node.logprobs  # ✅ Automatically aligned
        finish_reason = node.metadata["finish_reason"]

    # Score and return...
```
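The "Score and return..." step is left open above. As a rough sketch of what it might look like, the snippet below collects per-node data and a score into one group dictionary. Everything here is an assumption made for illustration: `Node` is a stand-in dataclass rather than the library's node type, `score_node` is a placeholder reward, and the group keys (`tokens`, `masks`, `logprobs`, `scores`) are guesses that should be checked against the actual `ScoredDataGroup` definition.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in exposing the node fields used in the pattern above."""

    tokens: list[int]
    masked_tokens: list[int]
    logprobs: list[float]
    metadata: dict = field(default_factory=dict)


def score_node(node: Node) -> float:
    """Placeholder reward: 1.0 if generation stopped on its own, else 0.0."""
    return 1.0 if node.metadata.get("finish_reason") == "stop" else 0.0


def build_group(nodes):
    """Collect per-node sequences and scores into one group dict.

    The key names are assumptions, not the confirmed ScoredDataGroup schema.
    """
    group = {"tokens": [], "masks": [], "logprobs": [], "scores": []}
    for node in nodes:
        group["tokens"].append(node.tokens)
        group["masks"].append(node.masked_tokens)
        group["logprobs"].append(node.logprobs)
        group["scores"].append(score_node(node))
    return group


# Two made-up nodes sharing a 2-token prompt.
nodes = [
    Node([1, 2, 3], [-100, -100, 3], [1.0, 1.0, -0.4], {"finish_reason": "stop"}),
    Node([1, 2, 4], [-100, -100, 4], [1.0, 1.0, -2.0], {"finish_reason": "length"}),
]
group = build_group(nodes)
```

A real environment would replace `score_node` with its own reward logic, for example by checking the decoded completion against a reference answer.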

### Chat Completion Pattern

For chat-based environments, use `chat_completion()`:

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes...
```

### Benefits Over Manual Handling

❌ **Without ManagedServer:**

- Manually tokenize prompts and completions
- Manually compute prompt lengths
- Manually apply masking logic
- Manually extract and align logprobs
- Prone to off-by-one errors

✅ **With ManagedServer:**

- Automatic tokenization
- Automatic masking
- Guaranteed alignment
- Clean, simple code
- Works with both `completion()` and `chat_completion()` APIs

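If you do handle masking by hand, a small consistency check can catch the off-by-one errors mentioned above. The function below is a hypothetical helper written for this illustration, not part of the library. It assumes the conventions described earlier: `-100` marks masked tokens, `1.0` marks masked logprobs, and real logprobs are never positive.

```python
def check_alignment(tokens, masked_tokens, logprobs):
    """Return a list of alignment problems; an empty list means consistent."""
    if not (len(tokens) == len(masked_tokens) == len(logprobs)):
        return ["length mismatch between tokens, masked_tokens and logprobs"]

    problems = []
    for i, (tok, masked, lp) in enumerate(zip(tokens, masked_tokens, logprobs)):
        if masked == -100:
            # Masked position: the logprob must carry the 1.0 marker.
            if lp != 1.0:
                problems.append(f"position {i}: masked token but logprob {lp}")
        else:
            # Trainable position: token must match and logprob must be valid.
            if masked != tok:
                problems.append(f"position {i}: masked token {masked} != token {tok}")
            if lp > 0.0:
                problems.append(f"position {i}: trainable token but logprob {lp}")
    return problems


# Correctly aligned: 2 prompt positions, 2 completion positions.
good = check_alignment([1, 2, 3, 4], [-100, -100, 3, 4], [1.0, 1.0, -0.5, -0.2])

# Logprobs shifted one position to the left, a typical manual mistake.
bad = check_alignment([1, 2, 3, 4], [-100, -100, 3, 4], [1.0, -0.5, -0.2, 1.0])
```

Here `good` is empty, while `bad` reports position 1 (masked token with a real logprob) and position 3 (trainable token with the mask marker).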

### Complete Documentation

For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and migration guide, see:

📚 **[ManagedServer Complete Guide](server_handling/MANAGED_SERVER.md)**

## Core Methods to Implement

These methods **must** be implemented in your subclass: