This commit is contained in:
Jai Suphavadeeprasit 2025-12-08 12:37:43 -05:00
parent f5c847d39d
commit a7bdc0270d
2 changed files with 295 additions and 5 deletions

View file

@ -371,6 +371,260 @@ pip install peft
---
## Checkpoint Locations
### Where Are Trained Models Saved?
| Mode | Location | Contents |
|------|----------|----------|
| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint |
| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) |
| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter |
### Customizing Save Path
```bash
python example_trainer/grpo.py \
--save-path /path/to/my/checkpoints \
...
```
### Loading Checkpoints for Inference
```python
# Full model (Legacy/Shared modes)
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model")
tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model")
# LoRA adapter
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter")
```
---
## vLLM Server Requirements
When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints:
### Required Endpoints
| Endpoint | Method | Purpose | Used By |
|----------|--------|---------|---------|
| `/health` | GET | Health check | All modes |
| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class |
### Required `/generate` Request Format
The vLLM server must handle **both** prompt formats:
```json
// String prompt (simple)
{
"prompt": "Hello, world!",
"max_tokens": 100,
"temperature": 1.0,
"logprobs": 1
}
// Token ID prompt (used by atroposlib)
{
"prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]},
"max_tokens": 100,
"temperature": 1.0,
"logprobs": 1
}
```
### Required `/generate` Response Format
```json
{
"text": ["generated text here"],
"prompt": "original prompt",
"finish_reasons": ["stop"],
"logprobs": [
[
[{"12345": -0.5}],
[{"67890": -1.2}]
]
],
"prompt_token_ids": [1, 2, 3, 4, 5],
"token_ids": [[12345, 67890, ...]]
}
```
The `logprobs` field format: `List[List[List[Dict[token_id, logprob]]]]`
- Outer list: per completion (n samples)
- Middle list: per token in completion
- Inner list: contains single dict `{token_id: logprob}`
### Optional Bridge Endpoints (for shared_vllm mode)
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/bridge/info` | GET | Get bridge status |
| `/bridge/notify_update` | POST | Receive weight update notifications |
| `/bridge/state_dict_info` | GET | Get model parameter mappings |
### Optional LoRA Endpoints (for lora_only mode)
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/lora/status` | GET | Get active LoRA adapter |
| `/lora/load` | POST | Load new LoRA adapter |
| `/lora/unload` | POST | Unload current adapter |
### Using Standard vLLM vs Custom Server
| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap |
|--------|-----------------------------------|-----------------|------------------------|
| `vllm serve ...` | ❌ No | ❌ No | ❌ No |
| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes |
**Use `example_trainer/vllm_api_server.py` for full feature support.**
---
## Benchmarking Speed & Memory
### Memory Usage Comparison
```bash
# Run this during training to monitor GPU memory
watch -n 1 nvidia-smi
```
**Expected Memory Usage (Qwen2.5-3B-Instruct):**
| Mode | Trainer GPU | vLLM GPU | Total |
|------|------------|----------|-------|
| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) |
| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) |
| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB |
### Speed Benchmarking
Add these measurements to your training script or use the built-in wandb logging:
```python
import time
import torch
# Track step times
step_times = []
sync_times = []
for step in range(training_steps):
# Measure training step time
step_start = time.time()
# ... training code ...
step_time = time.time() - step_start
step_times.append(step_time)
# Measure sync time (Legacy mode only)
if step % vllm_restart_interval == 0:
sync_start = time.time()
# ... checkpoint + restart vLLM ...
sync_time = time.time() - sync_start
sync_times.append(sync_time)
# Print summary
print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s")
print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s" if sync_times else "No syncs")
```
### Benchmark Script
Create a benchmark comparing modes:
```bash
#!/bin/bash
# benchmark_modes.sh
MODEL="Qwen/Qwen2.5-3B-Instruct"
STEPS=50
BATCH=2
ACCUM=16
echo "=== Benchmarking Legacy Mode ==="
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode none \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--vllm-restart-interval 10 \
2>&1 | tee benchmark_legacy.log
echo "=== Benchmarking Shared vLLM Mode ==="
export LOGDIR=/tmp/bench_shared
export NUM_INFERENCE_NODES=0
mkdir -p $LOGDIR
# Start vLLM first
python example_trainer/vllm_api_server.py \
--model $MODEL --port 9001 --gpu-memory-utilization 0.45 &
VLLM_PID=$!
sleep 60 # Wait for vLLM to load
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode shared_vllm \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--num-inference-nodes 0 \
2>&1 | tee benchmark_shared.log
kill $VLLM_PID
echo "=== Benchmarking LoRA Mode ==="
time python example_trainer/grpo.py \
--model-name $MODEL \
--weight-bridge-mode lora_only \
--training-steps $STEPS \
--batch-size $BATCH \
--gradient-accumulation-steps $ACCUM \
--lora-r 16 \
--vllm-restart-interval 25 \
2>&1 | tee benchmark_lora.log
echo "=== Summary ==="
echo "Check benchmark_*.log for detailed timing"
```
### Expected Benchmark Results
| Metric | Legacy | Shared vLLM | LoRA |
|--------|--------|-------------|------|
| **Step time** | ~2-5s | ~2-5s | ~1-3s |
| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps |
| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min |
| **Peak GPU memory** | ~16GB | ~8GB | ~10GB |
| **Checkpoint size** | ~6GB | ~6GB | ~50MB |
### WandB Metrics to Watch
If using `--use-wandb`, these metrics are logged:
| Metric | Description |
|--------|-------------|
| `train/loss` | GRPO loss |
| `train/grad_norm` | Gradient norm |
| `train/pos_logp` | Log prob of positive examples |
| `train/neg_logp` | Log prob of negative examples |
| `train/step_time` | Time per training step |
| `train/sync_time` | Time for weight sync (legacy/lora) |
---
## Files in This Directory
| File | Description |

View file

@ -682,9 +682,37 @@ def log_metrics(
wandb.log(log_dict, step=step)
def finalize_training(use_wandb: bool) -> None:
"""Clean up after training completes."""
def finalize_training(
use_wandb: bool,
training_start_time: Optional[float] = None,
mode: str = "unknown",
total_steps: int = 0,
) -> None:
"""Clean up after training and log benchmark summary."""
print("\nTraining finished.")
# Log benchmark summary
if training_start_time is not None:
total_time = time.time() - training_start_time
gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
print(f"\n{'='*60}")
print(f"BENCHMARK SUMMARY ({mode})")
print(f"{'='*60}")
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f" Total steps: {total_steps}")
print(f" Avg time per step: {total_time/max(total_steps,1):.2f}s")
print(f" Peak GPU memory: {gpu_mem_gb:.2f} GB")
print(f"{'='*60}\n")
if use_wandb:
wandb.summary["benchmark/total_time_seconds"] = total_time
wandb.summary["benchmark/total_time_minutes"] = total_time / 60
wandb.summary["benchmark/avg_step_time_seconds"] = total_time / max(total_steps, 1)
wandb.summary["benchmark/peak_gpu_memory_gb"] = gpu_mem_gb
wandb.summary["benchmark/mode"] = mode
wandb.summary["benchmark/total_steps"] = total_steps
if use_wandb:
wandb.finish()
@ -697,6 +725,7 @@ def train(config: TrainingConfig):
Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts.
"""
global vllm_process
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
@ -708,6 +737,7 @@ def train(config: TrainingConfig):
print(f"{'='*60}")
print(f"Training for {config.training_steps} steps on {config.device}")
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
print(f"Save path: {config.save_path}")
print(f"{'='*60}\n")
os.makedirs(config.save_path, exist_ok=True)
@ -753,8 +783,8 @@ def train(config: TrainingConfig):
_check_vllm_health()
# === Cleanup ===
finalize_training(use_wandb)
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps)
# =============================================================================
@ -846,6 +876,8 @@ def train_shared_vllm(config: TrainingConfig):
"Ensure vllm_weight_bridge.py is in the same directory."
)
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
@ -856,6 +888,7 @@ def train_shared_vllm(config: TrainingConfig):
print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
print(f"Init method: {config.init_method}")
print(f"Inference nodes: {config.num_inference_nodes}")
print(f"Save path: {config.save_path}")
print(f"{'='*60}\n")
# Initialize weight bridge
@ -907,8 +940,8 @@ def train_shared_vllm(config: TrainingConfig):
# === Cleanup ===
bridge.cleanup()
finalize_training(use_wandb)
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps)
def train_lora(config: TrainingConfig):
@ -926,6 +959,8 @@ def train_lora(config: TrainingConfig):
"PEFT library required for LoRA mode. Install with: pip install peft"
)
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
@ -934,6 +969,7 @@ def train_lora(config: TrainingConfig):
print(f"{'='*60}")
print(f"Base model: {config.model_name}")
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
print(f"Save path: {config.save_path}")
print(f"{'='*60}\n")
# Load model with LoRA adapters
@ -984,10 +1020,10 @@ def train_lora(config: TrainingConfig):
# === Cleanup ===
_terminate_vllm_process()
finalize_training(use_wandb)
# Save final adapter
save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps)
# Also save tokenizer for convenience
tokenizer_path = os.path.join(config.save_path, "tokenizer")