diff --git a/example_trainer/README.md b/example_trainer/README.md
index 409c3c7b..c396d45f 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -371,6 +371,260 @@ pip install peft
 
 ---
 
+## Checkpoint Locations
+
+### Where Are Trained Models Saved?
+
+| Mode | Location | Contents |
+|------|----------|----------|
+| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
+| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint |
+| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
+| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) |
+| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter |
+
+### Customizing the Save Path
+
+```bash
+python example_trainer/grpo.py \
+  --save-path /path/to/my/checkpoints \
+  ...
+```
+
+### Loading Checkpoints for Inference
+
+```python
+# Full model (Legacy/Shared modes)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model")
+tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model")
+
+# LoRA adapter (applied on top of the frozen base model)
+from peft import PeftModel
+
+base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
+model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter")
+```
+
+---
+
+## vLLM Server Requirements
+
+When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints:
+
+### Required Endpoints
+
+| Endpoint | Method | Purpose | Used By |
+|----------|--------|---------|---------|
+| `/health` | GET | Health check | All modes |
+| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class |
+
+### Required `/generate` Request Format
+
+The vLLM server must handle **both** prompt formats.
+
+**String prompt (simple):**
+
+```json
+{
+  "prompt": "Hello, world!",
+  "max_tokens": 100,
+  "temperature": 1.0,
+  "logprobs": 1
+}
+```
+
+**Token ID prompt (used by atroposlib):**
+
+```json
+{
+  "prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]},
+  "max_tokens": 100,
+  "temperature": 1.0,
+  "logprobs": 1
+}
+```
+
+### Required `/generate` Response Format
+
+```json
+{
+  "text": ["generated text here"],
+  "prompt": "original prompt",
+  "finish_reasons": ["stop"],
+  "logprobs": [
+    [
+      [{"12345": -0.5}],
+      [{"67890": -1.2}]
+    ]
+  ],
+  "prompt_token_ids": [1, 2, 3, 4, 5],
+  "token_ids": [[12345, 67890, ...]]
+}
+```
+
+The `logprobs` field has the format `List[List[List[Dict[token_id, logprob]]]]`:
+- Outer list: one entry per completion (n samples)
+- Middle list: one entry per token in the completion
+- Inner list: a single dict `{token_id: logprob}`
+
+### Optional Bridge Endpoints (for shared_vllm mode)
+
+| Endpoint | Method | Purpose |
+|----------|--------|---------|
+| `/bridge/info` | GET | Get bridge status |
+| `/bridge/notify_update` | POST | Receive weight update notifications |
+| `/bridge/state_dict_info` | GET | Get model parameter mappings |
+
+### Optional LoRA Endpoints (for lora_only mode)
+
+| Endpoint | Method | Purpose |
+|----------|--------|---------|
+| `/lora/status` | GET | Get the active LoRA adapter |
+| `/lora/load` | POST | Load a new LoRA adapter |
+| `/lora/unload` | POST | Unload the current adapter |
+
+### Using Standard vLLM vs Custom Server
+
+| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap |
+|--------|-----------------------------------|-----------------|------------------------|
+| `vllm serve ...` | ❌ No | ❌ No | ❌ No |
+| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes |
+
+**Use `example_trainer/vllm_api_server.py` for full feature support.**
+
+---
+
+## Benchmarking Speed & Memory
+
+### Memory Usage Comparison
+
+```bash
+# Run this during training to monitor GPU memory
+watch -n 1 nvidia-smi
+```
+
+**Expected Memory Usage (Qwen2.5-3B-Instruct):**
+
+| Mode | Trainer GPU | vLLM GPU | Total |
+|------|------------|----------|-------|
+| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) |
+| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) |
+| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB |
+
+### Speed Benchmarking
+
+Add these measurements to your training script or use the built-in wandb logging:
+
+```python
+import time
+
+import torch
+
+# Track step and sync times. training_steps and vllm_restart_interval
+# come from your training config.
+step_times = []
+sync_times = []
+
+for step in range(training_steps):
+    # Measure training step time
+    step_start = time.time()
+    # ... training code ...
+    step_times.append(time.time() - step_start)
+
+    # Measure sync time (Legacy mode only)
+    if step % vllm_restart_interval == 0:
+        sync_start = time.time()
+        # ... checkpoint + restart vLLM ...
+        sync_times.append(time.time() - sync_start)
+
+# Print summary
+print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s")
+if sync_times:
+    print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s")
+if torch.cuda.is_available():
+    print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
+```
+
+### Benchmark Script
+
+Create a script that benchmarks the three modes:
+
+```bash
+#!/bin/bash
+# benchmark_modes.sh
+
+MODEL="Qwen/Qwen2.5-3B-Instruct"
+STEPS=50
+BATCH=2
+ACCUM=16
+
+echo "=== Benchmarking Legacy Mode ==="
+time python example_trainer/grpo.py \
+  --model-name $MODEL \
+  --weight-bridge-mode none \
+  --training-steps $STEPS \
+  --batch-size $BATCH \
+  --gradient-accumulation-steps $ACCUM \
+  --vllm-restart-interval 10 \
+  2>&1 | tee benchmark_legacy.log
+
+echo "=== Benchmarking Shared vLLM Mode ==="
+export LOGDIR=/tmp/bench_shared
+export NUM_INFERENCE_NODES=0
+mkdir -p $LOGDIR
+
+# Start vLLM first
+python example_trainer/vllm_api_server.py \
+  --model $MODEL --port 9001 --gpu-memory-utilization 0.45 &
+VLLM_PID=$!
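+
+# Optional: instead of the fixed sleep below, poll the /health endpoint
+# documented above until the server is up. A sketch, assuming curl is
+# installed and the server listens on localhost:9001:
+#   until curl -sf http://localhost:9001/health > /dev/null; do sleep 2; done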
+sleep 60 # Wait for vLLM to load + +time python example_trainer/grpo.py \ + --model-name $MODEL \ + --weight-bridge-mode shared_vllm \ + --training-steps $STEPS \ + --batch-size $BATCH \ + --gradient-accumulation-steps $ACCUM \ + --num-inference-nodes 0 \ + 2>&1 | tee benchmark_shared.log + +kill $VLLM_PID + +echo "=== Benchmarking LoRA Mode ===" +time python example_trainer/grpo.py \ + --model-name $MODEL \ + --weight-bridge-mode lora_only \ + --training-steps $STEPS \ + --batch-size $BATCH \ + --gradient-accumulation-steps $ACCUM \ + --lora-r 16 \ + --vllm-restart-interval 25 \ + 2>&1 | tee benchmark_lora.log + +echo "=== Summary ===" +echo "Check benchmark_*.log for detailed timing" +``` + +### Expected Benchmark Results + +| Metric | Legacy | Shared vLLM | LoRA | +|--------|--------|-------------|------| +| **Step time** | ~2-5s | ~2-5s | ~1-3s | +| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps | +| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min | +| **Peak GPU memory** | ~16GB | ~8GB | ~10GB | +| **Checkpoint size** | ~6GB | ~6GB | ~50MB | + +### WandB Metrics to Watch + +If using `--use-wandb`, these metrics are logged: + +| Metric | Description | +|--------|-------------| +| `train/loss` | GRPO loss | +| `train/grad_norm` | Gradient norm | +| `train/pos_logp` | Log prob of positive examples | +| `train/neg_logp` | Log prob of negative examples | +| `train/step_time` | Time per training step | +| `train/sync_time` | Time for weight sync (legacy/lora) | + +--- + ## Files in This Directory | File | Description | diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index f9d24608..4464869f 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -682,9 +682,37 @@ def log_metrics( wandb.log(log_dict, step=step) -def finalize_training(use_wandb: bool) -> None: - """Clean up after training completes.""" +def finalize_training( + use_wandb: bool, + training_start_time: Optional[float] = None, + mode: str = "unknown", + total_steps: int = 0, +) -> None: + """Clean up after training and log benchmark summary.""" print("\nTraining finished.") + + # Log benchmark summary + if training_start_time is not None: + total_time = time.time() - training_start_time + gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + + print(f"\n{'='*60}") + print(f"BENCHMARK SUMMARY ({mode})") + print(f"{'='*60}") + print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)") + print(f" Total steps: {total_steps}") + print(f" Avg time per step: {total_time/max(total_steps,1):.2f}s") + print(f" Peak GPU memory: {gpu_mem_gb:.2f} GB") + print(f"{'='*60}\n") + + if use_wandb: + wandb.summary["benchmark/total_time_seconds"] = total_time + wandb.summary["benchmark/total_time_minutes"] = total_time / 60 + wandb.summary["benchmark/avg_step_time_seconds"] = total_time / max(total_steps, 1) + wandb.summary["benchmark/peak_gpu_memory_gb"] = gpu_mem_gb + wandb.summary["benchmark/mode"] = mode + wandb.summary["benchmark/total_steps"] = total_steps + if use_wandb: wandb.finish() @@ -697,6 +725,7 @@ def train(config: TrainingConfig): Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts. 
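+    Use weight_bridge_mode='lora_only' to train LoRA adapters instead of full weights.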
""" global vllm_process + training_start_time = time.time() # === Setup === use_wandb = setup_wandb(config) @@ -708,6 +737,7 @@ def train(config: TrainingConfig): print(f"{'='*60}") print(f"Training for {config.training_steps} steps on {config.device}") print(f"vLLM restart interval: every {config.vllm_restart_interval} steps") + print(f"Save path: {config.save_path}") print(f"{'='*60}\n") os.makedirs(config.save_path, exist_ok=True) @@ -753,8 +783,8 @@ def train(config: TrainingConfig): _check_vllm_health() # === Cleanup === - finalize_training(use_wandb) save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) + finalize_training(use_wandb, training_start_time, "legacy", config.training_steps) # ============================================================================= @@ -846,6 +876,8 @@ def train_shared_vllm(config: TrainingConfig): "Ensure vllm_weight_bridge.py is in the same directory." ) + training_start_time = time.time() + # === Setup === use_wandb = setup_wandb(config) @@ -856,6 +888,7 @@ def train_shared_vllm(config: TrainingConfig): print(f"Distributed: rank={config.trainer_rank}/{config.world_size}") print(f"Init method: {config.init_method}") print(f"Inference nodes: {config.num_inference_nodes}") + print(f"Save path: {config.save_path}") print(f"{'='*60}\n") # Initialize weight bridge @@ -907,8 +940,8 @@ def train_shared_vllm(config: TrainingConfig): # === Cleanup === bridge.cleanup() - finalize_training(use_wandb) save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) + finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps) def train_lora(config: TrainingConfig): @@ -926,6 +959,8 @@ def train_lora(config: TrainingConfig): "PEFT library required for LoRA mode. Install with: pip install peft" ) + training_start_time = time.time() + # === Setup === use_wandb = setup_wandb(config) @@ -934,6 +969,7 @@ def train_lora(config: TrainingConfig): print(f"{'='*60}") print(f"Base model: {config.model_name}") print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") print(f"{'='*60}\n") # Load model with LoRA adapters @@ -984,10 +1020,10 @@ def train_lora(config: TrainingConfig): # === Cleanup === _terminate_vllm_process() - finalize_training(use_wandb) # Save final adapter save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True) + finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps) # Also save tokenizer for convenience tokenizer_path = os.path.join(config.save_path, "tokenizer")