This commit is contained in:
Jai Suphavadeeprasit 2025-12-08 12:37:43 -05:00
parent f5c847d39d
commit a7bdc0270d
2 changed files with 295 additions and 5 deletions

---
## Checkpoint Locations
### Where Are Trained Models Saved?
| Mode | Location | Contents |
|------|----------|----------|
| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint |
| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) |
| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter |
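As a quick sanity check on the table above, a small helper (a sketch, not part of the trainer) can report the on-disk size of each checkpoint directory — useful for confirming that LoRA runs are saving MB-scale adapters rather than full models:

```python
import os

def dir_size_mb(path):
    """Total size of all files under `path`, in megabytes."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / (1024 * 1024)

def list_checkpoints(root="trained_model_checkpoints"):
    """Return {checkpoint_name: size_in_MB} for each saved checkpoint dir."""
    return {
        entry: round(dir_size_mb(os.path.join(root, entry)), 1)
        for entry in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, entry))
    }
```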
### Customizing Save Path
```bash
python example_trainer/grpo.py \
--save-path /path/to/my/checkpoints \
...
```
### Loading Checkpoints for Inference
```python
# Full model (Legacy/Shared modes)
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model")
tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model")
# LoRA adapter
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter")
```
---
## vLLM Server Requirements
When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints:
### Required Endpoints
| Endpoint | Method | Purpose | Used By |
|----------|--------|---------|---------|
| `/health` | GET | Health check | All modes |
| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class |
### Required `/generate` Request Format
The vLLM server must handle **both** prompt formats:
```json
// String prompt (simple)
{
"prompt": "Hello, world!",
"max_tokens": 100,
"temperature": 1.0,
"logprobs": 1
}
// Token ID prompt (used by atroposlib)
{
"prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]},
"max_tokens": 100,
"temperature": 1.0,
"logprobs": 1
}
```
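If you are scripting requests by hand, a small helper (illustrative only; the field names follow the request formats above) can build either payload variant:

```python
def build_generate_payload(prompt, max_tokens=100, temperature=1.0, logprobs=1):
    """Build a /generate request body.

    `prompt` may be a plain string or a list of token IDs; token IDs are
    wrapped in {"prompt_token_ids": [...]} the way atroposlib sends them.
    """
    if isinstance(prompt, str):
        prompt_field = prompt
    else:
        prompt_field = {"prompt_token_ids": list(prompt)}
    return {
        "prompt": prompt_field,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "logprobs": logprobs,
    }
```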
### Required `/generate` Response Format
```json
{
"text": ["generated text here"],
"prompt": "original prompt",
"finish_reasons": ["stop"],
"logprobs": [
[
[{"12345": -0.5}],
[{"67890": -1.2}]
]
],
"prompt_token_ids": [1, 2, 3, 4, 5],
"token_ids": [[12345, 67890, ...]]
}
```
The `logprobs` field has the shape `List[List[List[Dict[token_id, logprob]]]]`:
- Outer list: one entry per completion (n samples)
- Middle list: one entry per generated token
- Inner list: a single dict `{token_id: logprob}`
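Given that nesting, a client might flatten one completion's logprobs into parallel lists like this (a sketch; note that token IDs arrive as strings, since they are JSON object keys):

```python
def extract_logprobs(response, sample_idx=0):
    """Flatten the nested /generate logprobs structure for one completion.

    Expects response["logprobs"][sample][token] == [{token_id: logprob}]
    and returns ([token_ids...], [logprobs...]) as parallel lists.
    """
    token_ids, logps = [], []
    for token_entry in response["logprobs"][sample_idx]:
        (tid, lp), = token_entry[0].items()  # inner list holds a single-key dict
        token_ids.append(int(tid))           # JSON keys are strings
        logps.append(lp)
    return token_ids, logps
```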
### Optional Bridge Endpoints (for shared_vllm mode)
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/bridge/info` | GET | Get bridge status |
| `/bridge/notify_update` | POST | Receive weight update notifications |
| `/bridge/state_dict_info` | GET | Get model parameter mappings |
### Optional LoRA Endpoints (for lora_only mode)
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/lora/status` | GET | Get active LoRA adapter |
| `/lora/load` | POST | Load new LoRA adapter |
| `/lora/unload` | POST | Unload current adapter |
### Using Standard vLLM vs Custom Server
| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap |
|--------|-----------------------------------|-----------------|------------------------|
| `vllm serve ...` | ❌ No | ❌ No | ❌ No |
| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes |
**Use `example_trainer/vllm_api_server.py` for full feature support.**
---
## Benchmarking Speed & Memory
### Memory Usage Comparison
```bash
# Run this during training to monitor GPU memory
watch -n 1 nvidia-smi
```
**Expected Memory Usage (Qwen2.5-3B-Instruct):**
| Mode | Trainer GPU | vLLM GPU | Total |
|------|------------|----------|-------|
| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) |
| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) |
| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB |
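The per-copy numbers follow from a rough back-of-envelope: bf16/fp16 weights take 2 bytes per parameter, so a ~3B-parameter model is roughly 6 GB of raw weights before activations, KV cache, and optimizer state push it toward the ~8 GB in the table. A sketch of the arithmetic (parameter count approximate):

```python
def model_memory_gb(n_params, bytes_per_param=2):
    """Approximate GPU memory for raw weights only (bf16/fp16 = 2 bytes).

    Real usage is higher: activations, KV cache, and (for the trainer)
    gradients and optimizer state come on top of this.
    """
    return n_params * bytes_per_param / 1024**3

# Qwen2.5-3B: roughly 3.1e9 parameters -> ~5.8 GB of bf16 weights,
# consistent with ~8 GB per copy once runtime overhead is added.
weights_gb = model_memory_gb(3.1e9)
```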
### Speed Benchmarking
Add these measurements to your training script or use the built-in wandb logging:
```python
import time

# `training_steps` and `vllm_restart_interval` come from your trainer config.
step_times = []
sync_times = []

for step in range(training_steps):
    # Measure training step time
    step_start = time.time()
    # ... training code ...
    step_times.append(time.time() - step_start)

    # Measure sync time (Legacy mode only)
    if step % vllm_restart_interval == 0:
        sync_start = time.time()
        # ... checkpoint + restart vLLM ...
        sync_times.append(time.time() - sync_start)

# Print summary
print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s")
if sync_times:
    print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s")
else:
    print("No syncs")
```
### Benchmark Script
Create a benchmark comparing modes:
```bash
#!/bin/bash
# benchmark_modes.sh
MODEL="Qwen/Qwen2.5-3B-Instruct"
STEPS=50
BATCH=2
ACCUM=16
echo "=== Benchmarking Legacy Mode ==="
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode none \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--vllm-restart-interval 10 \
2>&1 | tee benchmark_legacy.log
echo "=== Benchmarking Shared vLLM Mode ==="
export LOGDIR=/tmp/bench_shared
export NUM_INFERENCE_NODES=0
mkdir -p $LOGDIR
# Start vLLM first
python example_trainer/vllm_api_server.py \
--model $MODEL --port 9001 --gpu-memory-utilization 0.45 &
VLLM_PID=$!
sleep 60 # Wait for vLLM to load
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode shared_vllm \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--num-inference-nodes 0 \
2>&1 | tee benchmark_shared.log
kill $VLLM_PID
echo "=== Benchmarking LoRA Mode ==="
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode lora_only \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--lora-r 16 \
--vllm-restart-interval 25 \
2>&1 | tee benchmark_lora.log
echo "=== Summary ==="
echo "Check benchmark_*.log for detailed timing"
```
### Expected Benchmark Results
| Metric | Legacy | Shared vLLM | LoRA |
|--------|--------|-------------|------|
| **Step time** | ~2-5s | ~2-5s | ~1-3s |
| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps |
| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min |
| **Peak GPU memory** | ~16GB | ~8GB | ~10GB |
| **Checkpoint size** | ~6GB | ~6GB | ~50MB |
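The ~50MB vs ~6GB checkpoint gap follows from LoRA's parameter count: for each adapted matrix of shape `(d_out, d_in)`, the adapter stores only the low-rank factors A (`r × d_in`) and B (`d_out × r`), i.e. `r * (d_in + d_out)` parameters. A sketch of the arithmetic with illustrative dimensions (not the exact Qwen2.5-3B config; the real adapter also includes metadata and possibly more target modules):

```python
def lora_adapter_params(layers, r, shapes):
    """Total LoRA parameters across `layers` transformer blocks.

    `shapes` lists the (d_out, d_in) of each adapted matrix per block;
    each contributes r * (d_in + d_out) parameters (the A and B factors).
    """
    per_layer = sum(r * (d_in + d_out) for (d_out, d_in) in shapes)
    return layers * per_layer

# Illustrative only: 36 blocks, r=16, four square 2048-wide attention
# projections per block -> ~9.4M params, ~18 MB in bf16 (2 bytes/param).
shapes = [(2048, 2048)] * 4
params = lora_adapter_params(36, 16, shapes)
size_mb = params * 2 / 1024**2
```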
### WandB Metrics to Watch
If using `--use-wandb`, these metrics are logged:
| Metric | Description |
|--------|-------------|
| `train/loss` | GRPO loss |
| `train/grad_norm` | Gradient norm |
| `train/pos_logp` | Log prob of positive examples |
| `train/neg_logp` | Log prob of negative examples |
| `train/step_time` | Time per training step |
| `train/sync_time` | Time for weight sync (legacy/lora) |
---
## Files in This Directory
| File | Description |