mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
Commit a7bdc0270d ("stuff", parent f5c847d39d): 2 changed files with 295 additions and 5 deletions
@@ -371,6 +371,260 @@ pip install peft

---
## Checkpoint Locations

### Where Are Trained Models Saved?

| Mode | Location | Contents |
|------|----------|----------|
| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint |
| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) |
| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter |
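To pick up training from the newest checkpoint, you can scan the layout above for the highest-numbered `step_N/` directory. A minimal sketch — the `latest_step_checkpoint` helper is illustrative, not part of atroposlib:

```python
import re
from pathlib import Path
from typing import Optional


def latest_step_checkpoint(root: str, prefix: str = "step_") -> Optional[Path]:
    """Return the highest-numbered step_N/ (or adapter_step_N/) directory, if any."""
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)$")
    best, best_n = None, -1
    for p in Path(root).iterdir():
        m = pattern.match(p.name)
        if p.is_dir() and m and int(m.group(1)) > best_n:
            best, best_n = p, int(m.group(1))
    return best
```

Pass `prefix="adapter_step_"` to scan LoRA-mode checkpoints instead.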
### Customizing Save Path

```bash
python example_trainer/grpo.py \
    --save-path /path/to/my/checkpoints \
    ...
```
### Loading Checkpoints for Inference

```python
# Full model (Legacy/Shared modes)
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model")
tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model")

# LoRA adapter: load the frozen base model, then attach the trained adapter
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter")
```
---

## vLLM Server Requirements

When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints:
### Required Endpoints

| Endpoint | Method | Purpose | Used By |
|----------|--------|---------|---------|
| `/health` | GET | Health check | All modes |
| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class |
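Before sending work to the server, it helps to poll `/health` until it answers. A minimal sketch using only the standard library — the `wait_for_health` helper and its poll interval are illustrative, not part of atroposlib:

```python
import time
import urllib.error
import urllib.request


def wait_for_health(base_url: str, timeout: float = 120.0, interval: float = 2.0) -> bool:
    """Poll GET /health until the server returns 200 or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, ConnectionError):
            pass  # server not up yet; retry after the interval
        time.sleep(interval)
    return False
```

For example, call `wait_for_health("http://localhost:9001")` before launching the trainer instead of a fixed `sleep`.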
### Required `/generate` Request Format

The vLLM server must handle **both** prompt formats:

```json
// String prompt (simple)
{
  "prompt": "Hello, world!",
  "max_tokens": 100,
  "temperature": 1.0,
  "logprobs": 1
}

// Token ID prompt (used by atroposlib)
{
  "prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]},
  "max_tokens": 100,
  "temperature": 1.0,
  "logprobs": 1
}
```
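A small helper that produces either payload shape keeps client code consistent. This is a sketch; `build_generate_request` is a hypothetical name, not a function in atroposlib:

```python
from typing import Any, Dict, List, Union


def build_generate_request(
    prompt: Union[str, List[int]],
    max_tokens: int = 100,
    temperature: float = 1.0,
    logprobs: int = 1,
) -> Dict[str, Any]:
    """Build a /generate payload from either a string or a token-ID prompt."""
    if isinstance(prompt, str):
        body: Dict[str, Any] = {"prompt": prompt}
    else:
        # Token-ID prompts are wrapped in a {"prompt_token_ids": [...]} object
        body = {"prompt": {"prompt_token_ids": list(prompt)}}
    body.update(max_tokens=max_tokens, temperature=temperature, logprobs=logprobs)
    return body
```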
### Required `/generate` Response Format

```json
{
  "text": ["generated text here"],
  "prompt": "original prompt",
  "finish_reasons": ["stop"],
  "logprobs": [
    [
      [{"12345": -0.5}],
      [{"67890": -1.2}]
    ]
  ],
  "prompt_token_ids": [1, 2, 3, 4, 5],
  "token_ids": [[12345, 67890, ...]]
}
```

The `logprobs` field format is `List[List[List[Dict[token_id, logprob]]]]`:
- Outer list: per completion (n samples)
- Middle list: per token in the completion
- Inner list: contains a single dict `{token_id: logprob}`
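To recover a flat per-token logprob list for one completion, walk the nested structure described above. A minimal sketch (the helper name is illustrative):

```python
from typing import Dict, List


def flatten_logprobs(
    logprobs: List[List[List[Dict[str, float]]]], sample: int = 0
) -> List[float]:
    """Extract the chosen-token logprob for each token of one completion.

    Outer list = completions, middle = tokens, inner = a one-element
    list holding a single {token_id: logprob} dict.
    """
    flat = []
    for token_entry in logprobs[sample]:
        # token_entry[0] is the {token_id: logprob} dict for this position
        (logprob,) = token_entry[0].values()
        flat.append(logprob)
    return flat
```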
### Optional Bridge Endpoints (for shared_vllm mode)

| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/bridge/info` | GET | Get bridge status |
| `/bridge/notify_update` | POST | Receive weight update notifications |
| `/bridge/state_dict_info` | GET | Get model parameter mappings |
### Optional LoRA Endpoints (for lora_only mode)

| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/lora/status` | GET | Get active LoRA adapter |
| `/lora/load` | POST | Load new LoRA adapter |
| `/lora/unload` | POST | Unload current adapter |
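Hot-swapping an adapter from a script can be sketched with the standard library alone. Note that the `{"adapter_path": ...}` body shape is an assumption, not a documented contract — verify the expected payload against `vllm_api_server.py` before relying on it:

```python
import json
import urllib.request


def load_lora_adapter(base_url: str, adapter_path: str) -> int:
    """POST /lora/load with a JSON body; returns the HTTP status code.

    NOTE: the {"adapter_path": ...} body shape is an assumption --
    check vllm_api_server.py for the actual request schema.
    """
    body = json.dumps({"adapter_path": adapter_path}).encode()
    req = urllib.request.Request(
        f"{base_url}/lora/load",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return resp.status
```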
### Using Standard vLLM vs Custom Server

| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap |
|--------|------------------------------------|-----------------|------------------------|
| `vllm serve ...` | ❌ No | ❌ No | ❌ No |
| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes |

**Use `example_trainer/vllm_api_server.py` for full feature support.**
---

## Benchmarking Speed & Memory

### Memory Usage Comparison

```bash
# Run this during training to monitor GPU memory
watch -n 1 nvidia-smi
```

**Expected Memory Usage (Qwen2.5-3B-Instruct):**

| Mode | Trainer GPU | vLLM GPU | Total |
|------|-------------|----------|-------|
| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) |
| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) |
| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB |
### Speed Benchmarking

Add these measurements to your training script or use the built-in wandb logging:

```python
import time

# Track step times
step_times = []
sync_times = []

for step in range(training_steps):
    # Measure training step time
    step_start = time.time()
    # ... training code ...
    step_times.append(time.time() - step_start)

    # Measure sync time (Legacy mode only)
    if step % vllm_restart_interval == 0:
        sync_start = time.time()
        # ... checkpoint + restart vLLM ...
        sync_times.append(time.time() - sync_start)

# Print summary
print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s")
if sync_times:
    print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s")
else:
    print("No syncs")
```
### Benchmark Script

Create a benchmark comparing modes:

```bash
#!/bin/bash
# benchmark_modes.sh

MODEL="Qwen/Qwen2.5-3B-Instruct"
STEPS=50
BATCH=2
ACCUM=16

echo "=== Benchmarking Legacy Mode ==="
time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode none \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --vllm-restart-interval 10 \
    2>&1 | tee benchmark_legacy.log

echo "=== Benchmarking Shared vLLM Mode ==="
export LOGDIR=/tmp/bench_shared
export NUM_INFERENCE_NODES=0
mkdir -p $LOGDIR

# Start vLLM first
python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9001 --gpu-memory-utilization 0.45 &
VLLM_PID=$!
sleep 60  # Wait for vLLM to load

time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode shared_vllm \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --num-inference-nodes 0 \
    2>&1 | tee benchmark_shared.log

kill $VLLM_PID

echo "=== Benchmarking LoRA Mode ==="
time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode lora_only \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --lora-r 16 \
    --vllm-restart-interval 25 \
    2>&1 | tee benchmark_lora.log

echo "=== Summary ==="
echo "Check benchmark_*.log for detailed timing"
```
### Expected Benchmark Results

| Metric | Legacy | Shared vLLM | LoRA |
|--------|--------|-------------|------|
| **Step time** | ~2-5s | ~2-5s | ~1-3s |
| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps |
| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min |
| **Peak GPU memory** | ~16GB | ~8GB | ~10GB |
| **Checkpoint size** | ~6GB | ~6GB | ~50MB |
### WandB Metrics to Watch

If using `--use-wandb`, these metrics are logged:

| Metric | Description |
|--------|-------------|
| `train/loss` | GRPO loss |
| `train/grad_norm` | Gradient norm |
| `train/pos_logp` | Log prob of positive examples |
| `train/neg_logp` | Log prob of negative examples |
| `train/step_time` | Time per training step |
| `train/sync_time` | Time for weight sync (legacy/lora) |
---

## Files in This Directory

| File | Description |
|------|-------------|