add docs :)

This commit is contained in:
Dakota 2025-10-29 11:26:43 -05:00
parent c3a118f50d
commit 5d6d6bb0dc
6 changed files with 892 additions and 21 deletions

View file

@ -17,6 +17,98 @@ To achieve this generality, our environment abstraction deviates from other open
- **Environments return tokens (not messages!)**: One of the most peculiar design choices we made was that at least for text-only environments, environments are responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal where a OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55).
## Working with Servers and ManagedServer
**🎯 Recommended Approach:** Use `ManagedServer` for automatic token and logprob tracking!
When implementing `collect_trajectory` or `collect_trajectories`, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The **recommended way** to do this is using `ManagedServer`, which automatically handles tokenization, masking, and logprob alignment.
### ManagedServer Overview
`ManagedServer` wraps your `APIServer` and automatically tracks:
- **Tokens**: Full unmasked token sequences
- **Masked Tokens**: Training format with `-100` for prompt positions, actual token IDs for completion
- **Logprobs**: Training format with `1.0` for masked positions, actual logprob values for completion
- **Full Text**: Complete text (prompt + completion)
- **Metadata**: Finish reasons and other information
**Why 1.0 for masked logprobs?** It represents an "obviously bad" probability (e^1.0 ≈ 2.718 > 1.0, which is invalid), making it easy to identify and ignore during training.
### Basic Usage Pattern
```python
async def collect_trajectories(self, item):
prompt = format_prompt(item)
# Use managed server with context manager
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=4096,
temperature=1.0,
)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
# Extract pre-computed, guaranteed-aligned data
for choice, node in zip(completion.choices, nodes):
tokens = node.tokens # ✅ Automatically computed
masked_tokens = node.masked_tokens # ✅ Automatically masked
logprobs = node.logprobs # ✅ Automatically aligned
finish_reason = node.metadata["finish_reason"]
# Score and return...
```
### Chat Completion Pattern
For chat-based environments, use `chat_completion()`:
```python
async def collect_trajectories(self, item):
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": item["question"]},
]
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completion = await managed.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=4096,
)
state = managed.get_state()
nodes = state["nodes"]
# Process nodes...
```
### Benefits Over Manual Handling
❌ **Without ManagedServer:**
- Manually tokenize prompts and completions
- Manually compute prompt lengths
- Manually apply masking logic
- Manually extract and align logprobs
- Prone to off-by-one errors
✅ **With ManagedServer:**
- Automatic tokenization
- Automatic masking
- Guaranteed alignment
- Clean, simple code
- Works with both `completion()` and `chat_completion()` APIs
### Complete Documentation
For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and migration guide, see:
📚 **[ManagedServer Complete Guide](server_handling/MANAGED_SERVER.md)**
## Core Methods to Implement
These methods **must** be implemented in your subclass: