diff --git a/example_trainer/README.md b/example_trainer/README.md
index 910ec56f..c197250b 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -58,21 +58,34 @@ Data Flow:
 
 ---
 
-## Three Training Modes
-
-| Mode | Description | Memory | Best For |
-|------|-------------|--------|----------|
-| **shared_vllm** | Single-copy via CUDA IPC | 1x model | Same GPU, maximum efficiency |
-| **lora_only** | Train adapters, HTTP hot-swap | 1x + small adapter | Simple setup, debugging |
-| **legacy** | Full model, restart vLLM | 2x model | Different GPUs, simple setup |
+## Four Training Modes
+
+| Mode | Description | Memory | Inference Speed | Best For |
+|------|-------------|--------|-----------------|----------|
+| **shared_vllm** | Single-copy via CUDA IPC | 1x model | ~170 TPS | Same GPU, maximum efficiency |
+| **lora_restart** | LoRA + vLLM restarts | 1x + adapter | ~170 TPS | LoRA training with speed |
+| **lora_only** | LoRA + HTTP hot-swap | 1x + adapter | ~13 TPS ⚠️ | Debugging only |
+| **legacy** | Full model, restart vLLM | 2x model | ~170 TPS | Different GPUs, simple setup |
+
+### ⚠️ IMPORTANT: `lora_only` Performance Warning
+
+The `lora_only` mode requires `--enforce-eager`, which **disables CUDA graphs**, resulting in:
+- **~13x slower inference** (~13 TPS vs ~170 TPS)
+- Training that takes **~3x longer** (401 min vs 132 min for 120 steps)
+
+**Use `lora_restart` instead** - it restarts vLLM to keep CUDA graphs enabled.
 
 ### Recommendation
 
-**Start with `lora_only`** - it's the easiest to set up and debug.
+**Use `shared_vllm`** for production training when:
+- You have enough GPU memory for the full model
+- You want the fastest training (no restart overhead)
 
-**Use `shared_vllm`** for production training when you need:
-- Fastest weight synchronization (CUDA IPC, zero-copy updates)
-- True on-policy training (vLLM sees updates immediately)
+**Use `lora_restart`** when:
+- You want LoRA's memory efficiency
+- You want fast inference (~170 TPS with CUDA graphs)
+- You can tolerate ~45s of restart overhead every N steps
+
+**Avoid `lora_only`** unless you're debugging - the ~13x inference penalty is severe.
 
-**Use `shared_vllm`** for single-GPU training when you need maximum efficiency.
 
diff --git a/example_trainer/cli.py b/example_trainer/cli.py
index 046e3ec3..847e4202 100644
--- a/example_trainer/cli.py
+++ b/example_trainer/cli.py
@@ -121,6 +121,12 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
         default=9001,
         help="Port for the vLLM server",
     )
+    group.add_argument(
+        "--vllm-gpu",
+        type=int,
+        default=None,
+        help="GPU ID for vLLM server. If not set, uses same GPU as trainer.",
+    )
     group.add_argument(
         "--gpu-memory-utilization",
         "--vllm-gpu-memory-utilization",
@@ -146,7 +152,7 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
         "--vllm-restart-interval",
         type=int,
         default=3,
-        help="Restart vLLM every N training steps (legacy mode only)",
+        help="Restart vLLM every N training steps (legacy/lora_restart modes)",
     )
 
 
@@ -189,9 +195,12 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
     group.add_argument(
         "--weight-bridge-mode",
         type=str,
-        choices=["shared_vllm", "lora_only", "none"],
+        choices=["shared_vllm", "lora_only", "lora_restart", "none"],
         default="none",
-        help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
+        help=(
+            "Weight sync mode: 'shared_vllm' (CUDA IPC), 'lora_only' (slow, --enforce-eager), "
+            "'lora_restart' (fast, restarts vLLM), or 'none' (legacy)"
+        ),
     )
     group.add_argument(
         "--vllm-config-path",
@@ -348,6 +357,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
         # vLLM settings
         vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
         vllm_port=args.vllm_port,
+        vllm_gpu=getattr(args, "vllm_gpu", None),
         vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
         max_model_len=getattr(args, "max_model_len", 4096),
         dtype=getattr(args, "dtype", "bfloat16"),
diff --git a/example_trainer/config.py b/example_trainer/config.py
index 6a99a8e1..c43524aa 100644
--- a/example_trainer/config.py
+++ b/example_trainer/config.py
@@ -87,6 +87,13 @@ class TrainingConfig(BaseModel):
         3, description="Restart vLLM every N training steps (legacy mode)"
     )
     vllm_port: int = Field(9001, description="Port for the vLLM server")
+    vllm_gpu: Optional[int] = Field(
+        None,
+        description=(
+            "GPU ID for vLLM server (lora_restart/legacy modes). If None, uses "
+            "the trainer's GPU; set a different ID to run the two in parallel."
+        ),
+    )
     vllm_gpu_memory_utilization: float = Field(
         0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
     )
@@ -105,12 +112,13 @@ class TrainingConfig(BaseModel):
     wandb_group: Optional[str] = Field(None, description="Wandb group name")
 
     # === Training Mode Configuration ===
-    weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
+    weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
         "none",
         description=(
             "How to synchronize weights with inference server. "
             "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
-            "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. "
+            "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
+            "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
            "'none': legacy mode, restart vLLM with new checkpoint files."
         ),
     )
diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 5b5d39a1..41eb0063 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -2,10 +2,11 @@
 """
 GRPO (Group Relative Policy Optimization) Trainer.
 
-Supports three training modes:
+Supports four training modes:
 - none (legacy): Periodic checkpoint saves + vLLM restarts
 - shared_vllm: Single-copy mode with CUDA IPC weight sharing
-- lora_only: LoRA adapter training with HTTP hot-swap
+- lora_only: LoRA adapter training with HTTP hot-swap (SLOW - needs --enforce-eager)
+- lora_restart: LoRA training with vLLM restarts (FAST - CUDA graphs enabled)
 
 Usage:
     # Legacy mode (manages vLLM internally)
     python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct
 
@@ -15,13 +16,18 @@ Usage:
     python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
         --weight-bridge-mode shared_vllm
 
-    # LoRA mode (requires external vLLM with --enable-lora --enforce-eager)
+    # LoRA mode with HTTP hot-swap (SLOW - ~13 TPS due to --enforce-eager)
     python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
         --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
+
+    # LoRA mode with vLLM restarts (FAST - ~170 TPS with CUDA graphs)
+    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
+        --weight-bridge-mode lora_restart --lora-r 16 --lora-alpha 32 \\
+        --vllm-restart-interval 3
 """
 
 from .cli import config_from_args, parse_args
-from .trainers import train_legacy, train_lora, train_shared_vllm
+from .trainers import train_legacy, train_lora, train_lora_restart, train_shared_vllm
 
 
 def main():
@@ -44,8 +50,14 @@ def main():
 
     elif config.weight_bridge_mode == "lora_only":
         # LoRA mode: freeze base model, train adapters only (HTTP hot-swap)
+        # WARNING: This is SLOW (~13 TPS) because it requires --enforce-eager
         train_lora(config)
 
+    elif config.weight_bridge_mode == "lora_restart":
+        # LoRA mode with vLLM restarts (FAST - uses CUDA graphs)
+        # Restarts vLLM every vllm_restart_interval steps with new adapter
+        train_lora_restart(config)
+
     else:
         # Legacy mode: periodic checkpoint saves + vLLM restarts
         train_legacy(config)
diff --git a/example_trainer/scripts/compare_lora_modes.sh b/example_trainer/scripts/compare_lora_modes.sh
new file mode 100755
index 00000000..a8f554cd
--- /dev/null
+++ b/example_trainer/scripts/compare_lora_modes.sh
@@ -0,0 +1,326 @@
+#!/bin/bash
+# ============================================================================
+# Compare lora_restart vs lora_only performance
+# ============================================================================
+# Runs both modes in parallel with separate APIs/environments/ports
+# All commands run in background (single terminal)
+# Results uploaded to W&B
+#
+# Usage:
+#   ./compare_lora_modes.sh [steps]
+#   ./compare_lora_modes.sh 30    # 30 steps (default)
+#   ./compare_lora_modes.sh 10    # Quick 10-step test
+# ============================================================================
+
+set -e
+
+# Configuration
+MODEL="Qwen/Qwen3-4B-Instruct-2507"
+STEPS="${1:-30}"
+RESTART_INTERVAL=3
+WANDB_PROJECT="lora-mode-comparison"
+
+# Port allocation
+#   lora_restart: API 8001, vLLM 9001
+#   lora_only:    API 8002, vLLM 9002
+
+echo "============================================================================"
+echo "LoRA Mode Comparison: lora_restart vs lora_only"
+echo "============================================================================"
+echo "Model: $MODEL"
+echo "Steps: $STEPS"
+echo "Restart interval: $RESTART_INTERVAL"
+echo "W&B project: $WANDB_PROJECT"
+echo ""
+echo "Port allocation:"
+echo "  lora_restart: API=8001, vLLM=9001, GPU=0"
+echo "  lora_only:    API=8002, vLLM=9002, GPU=1"
+echo "============================================================================"
+
+# Get script directory and repo root
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+cd "$REPO_ROOT"
+echo "Working directory: $(pwd)"
+
+# Create log directory
+LOGDIR="./lora_comparison_$(date +%Y%m%d_%H%M%S)"
+mkdir -p "$LOGDIR"
+echo "Log directory: $LOGDIR"
+
+# Cleanup function
+cleanup() {
+    echo ""
+    echo "Cleaning up all processes..."
+
+    # Kill by name
+    pkill -f "gsm8k_server.py" 2>/dev/null || true
+    pkill -f "run-api" 2>/dev/null || true
+    pkill -f "vllm_api_server.py" 2>/dev/null || true
+    pkill -f "example_trainer.grpo" 2>/dev/null || true
+
+    # Kill by port
+    for port in 8001 8002 9001 9002; do
+        fuser -k ${port}/tcp 2>/dev/null || true
+    done
+
+    echo "Cleanup complete."
+}
+trap cleanup EXIT
+
+# Initial cleanup
+echo ""
+echo "Killing any existing processes on ports 8001, 8002, 9001, 9002..."
+cleanup
+sleep 3
+
+# ============================================================================
+# MODE 1: lora_restart (GPU 0, ports 8001/9001)
+# ============================================================================
+echo ""
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+echo "[1/2] LORA_RESTART MODE (GPU 0, API:8001, vLLM:9001)"
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+
+# Start API for lora_restart
+echo "  Starting API server (port 8001)..."
+run-api --port 8001 > "$LOGDIR/api_restart.log" 2>&1 &
+RESTART_API_PID=$!
+sleep 3
+
+# Check API is up
+if curl -s "http://localhost:8001/info" > /dev/null 2>&1; then
+    echo "  ✓ API running (PID: $RESTART_API_PID)"
+else
+    echo "  ✗ API failed to start"
+    cat "$LOGDIR/api_restart.log"
+    exit 1
+fi
+
+# Start trainer (lora_restart manages vLLM internally)
+echo "  Starting lora_restart trainer (will launch vLLM on port 9001)..."
+CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
+    --model-name "$MODEL" \
+    --weight-bridge-mode lora_restart \
+    --vllm-port 9001 \
+    --atropos-url http://localhost:8001 \
+    --lora-r 16 \
+    --lora-alpha 32 \
+    --training-steps $STEPS \
+    --vllm-restart-interval $RESTART_INTERVAL \
+    --save-path "$LOGDIR/checkpoints_restart" \
+    --use-wandb \
+    --wandb-project "$WANDB_PROJECT" \
+    --wandb-group "comparison-$(date +%Y%m%d)" \
+    --benchmark \
+    > "$LOGDIR/trainer_restart.log" 2>&1 &
+RESTART_TRAINER_PID=$!
+echo "  ✓ Trainer started (PID: $RESTART_TRAINER_PID)"
+
+# Wait for vLLM to be ready (trainer launches it)
+echo "  Waiting for vLLM to start (port 9001)..."
+for i in {1..60}; do
+    if curl -s "http://localhost:9001/health" > /dev/null 2>&1; then
+        echo "  ✓ vLLM ready after ~$((i * 2))s"
+        break
+    fi
+    sleep 2
+done
+
+# Start environment for lora_restart
+echo "  Starting environment..."
+python -u environments/gsm8k_server.py serve \
+    --env.tokenizer_name "$MODEL" \
+    --env.rollout_server_url "http://localhost:8001" \
+    --env.max_token_length 2048 \
+    --env.use_wandb=True \
+    --env.wandb_name "lora-restart-env" \
+    --openai.model_name "$MODEL" \
+    --openai.base_url "http://localhost:9001/v1" \
+    --openai.server_type vllm \
+    --slurm false \
+    > "$LOGDIR/env_restart.log" 2>&1 &
+RESTART_ENV_PID=$!
+echo " ✓ Environment started (PID: $RESTART_ENV_PID)" + +# ============================================================================ +# MODE 2: lora_only (GPU 1, ports 8002/9002) +# ============================================================================ +echo "" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "[2/2] LORA_ONLY MODE (GPU 1, API:8002, vLLM:9002)" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Start API for lora_only +echo " Starting API server (port 8002)..." +run-api --port 8002 > "$LOGDIR/api_only.log" 2>&1 & +ONLY_API_PID=$! +sleep 3 + +# Check API is up +if curl -s "http://localhost:8002/info" > /dev/null 2>&1; then + echo " ✓ API running (PID: $ONLY_API_PID)" +else + echo " ✗ API failed to start" + cat "$LOGDIR/api_only.log" + exit 1 +fi + +# Start vLLM for lora_only (external, with --enforce-eager) +echo " Starting vLLM with --enable-lora --enforce-eager (port 9002)..." +CUDA_VISIBLE_DEVICES=1 python example_trainer/vllm_api_server.py \ + --model "$MODEL" \ + --port 9002 \ + --gpu-memory-utilization 0.45 \ + --enable-lora \ + --max-lora-rank 32 \ + --enforce-eager \ + > "$LOGDIR/vllm_only.log" 2>&1 & +ONLY_VLLM_PID=$! +echo " ✓ vLLM started (PID: $ONLY_VLLM_PID)" + +# Wait for vLLM to be ready +echo " Waiting for vLLM to start (port 9002)..." +for i in {1..90}; do + if curl -s "http://localhost:9002/health" > /dev/null 2>&1; then + echo " ✓ vLLM ready after ~${i}s" + break + fi + sleep 2 +done + +# Start environment for lora_only +echo " Starting environment..." +python -u environments/gsm8k_server.py serve \ + --env.tokenizer_name "$MODEL" \ + --env.rollout_server_url "http://localhost:8002" \ + --env.max_token_length 2048 \ + --env.use_wandb=True \ + --env.wandb_name "lora-only-env" \ + --openai.model_name "$MODEL" \ + --openai.base_url "http://localhost:9002/v1" \ + --openai.server_type vllm \ + --slurm false \ + > "$LOGDIR/env_only.log" 2>&1 & +ONLY_ENV_PID=$! +echo " ✓ Environment started (PID: $ONLY_ENV_PID)" + +# Start trainer for lora_only +echo " Starting lora_only trainer..." +CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \ + --model-name "$MODEL" \ + --weight-bridge-mode lora_only \ + --vllm-port 9002 \ + --atropos-url http://localhost:8002 \ + --lora-r 16 \ + --lora-alpha 32 \ + --training-steps $STEPS \ + --save-path "$LOGDIR/checkpoints_only" \ + --use-wandb \ + --wandb-project "$WANDB_PROJECT" \ + --wandb-group "comparison-$(date +%Y%m%d)" \ + --benchmark \ + > "$LOGDIR/trainer_only.log" 2>&1 & +ONLY_TRAINER_PID=$! +echo " ✓ Trainer started (PID: $ONLY_TRAINER_PID)" + +# ============================================================================ +# Save PIDs and monitor +# ============================================================================ +cat > "$LOGDIR/pids.txt" << EOF +RESTART_API_PID=$RESTART_API_PID +RESTART_TRAINER_PID=$RESTART_TRAINER_PID +RESTART_ENV_PID=$RESTART_ENV_PID +ONLY_API_PID=$ONLY_API_PID +ONLY_VLLM_PID=$ONLY_VLLM_PID +ONLY_ENV_PID=$ONLY_ENV_PID +ONLY_TRAINER_PID=$ONLY_TRAINER_PID +EOF + +echo "" +echo "============================================================================" +echo "All components started!" 
+echo "============================================================================" +echo "" +echo "📊 Monitor progress:" +echo " tail -f $LOGDIR/trainer_restart.log # lora_restart" +echo " tail -f $LOGDIR/trainer_only.log # lora_only" +echo "" +echo "🔍 Watch both:" +echo " tail -f $LOGDIR/trainer_*.log" +echo "" +echo "📈 W&B Dashboard:" +echo " https://wandb.ai/$WANDB_PROJECT" +echo "" +echo "Waiting for trainers to complete..." +echo "(lora_restart should finish MUCH faster than lora_only)" +echo "" + +# Wait for trainers +RESTART_STATUS="running" +ONLY_STATUS="running" + +while [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; do + sleep 30 + + # Check lora_restart + if [ "$RESTART_STATUS" = "running" ]; then + if ! kill -0 $RESTART_TRAINER_PID 2>/dev/null; then + wait $RESTART_TRAINER_PID 2>/dev/null && RESTART_STATUS="completed" || RESTART_STATUS="failed" + echo " lora_restart: $RESTART_STATUS" + fi + fi + + # Check lora_only + if [ "$ONLY_STATUS" = "running" ]; then + if ! kill -0 $ONLY_TRAINER_PID 2>/dev/null; then + wait $ONLY_TRAINER_PID 2>/dev/null && ONLY_STATUS="completed" || ONLY_STATUS="failed" + echo " lora_only: $ONLY_STATUS" + fi + fi + + # Show status + if [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; then + echo " [$(date +%H:%M:%S)] lora_restart: $RESTART_STATUS, lora_only: $ONLY_STATUS" + fi +done + +# ============================================================================ +# Print results +# ============================================================================ +echo "" +echo "============================================================================" +echo "COMPARISON RESULTS" +echo "============================================================================" + +echo "" +echo "📊 LORA_RESTART (CUDA graphs, vLLM restarts):" +echo "─────────────────────────────────────────────────" +grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_restart.log" 2>/dev/null || echo " (check $LOGDIR/trainer_restart.log)" + +echo "" +echo "📊 LORA_ONLY (--enforce-eager, hot-swap):" +echo "─────────────────────────────────────────────────" +grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_only.log" 2>/dev/null || echo " (check $LOGDIR/trainer_only.log)" + +echo "" +echo "============================================================================" +echo "📁 LOGS SAVED TO: $LOGDIR" +echo "============================================================================" +echo "" +echo "Log files:" +echo " $LOGDIR/trainer_restart.log # lora_restart trainer" +echo " $LOGDIR/trainer_only.log # lora_only trainer" +echo " $LOGDIR/vllm_only.log # lora_only vLLM" +echo " $LOGDIR/env_restart.log # lora_restart environment" +echo " $LOGDIR/env_only.log # lora_only environment" +echo "" +echo "Checkpoints:" +echo " $LOGDIR/checkpoints_restart/" +echo " $LOGDIR/checkpoints_only/" +echo "" +echo "W&B runs should be visible at:" +echo " https://wandb.ai/$WANDB_PROJECT" +echo "" +echo "============================================================================" +echo "Done!" 
diff --git a/example_trainer/scripts/test_lora_restart.sh b/example_trainer/scripts/test_lora_restart.sh
new file mode 100755
index 00000000..27177917
--- /dev/null
+++ b/example_trainer/scripts/test_lora_restart.sh
@@ -0,0 +1,139 @@
+#!/bin/bash
+# Quick test script for lora_restart mode
+# Tests that the mode works and compares timing
+
+set -e
+
+MODEL="Qwen/Qwen3-4B-Instruct-2507"
+STEPS=10
+RESTART_INTERVAL=3
+
+echo "=============================================="
+echo "Testing lora_restart mode"
+echo "=============================================="
+echo "Model: $MODEL"
+echo "Steps: $STEPS"
+echo "Restart interval: $RESTART_INTERVAL"
+echo "=============================================="
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+cd "$REPO_ROOT"
+
+# Create log directory
+LOGDIR="./lora_restart_test_$(date +%Y%m%d_%H%M%S)"
+mkdir -p "$LOGDIR"
+echo "Logs: $LOGDIR"
+
+# Cleanup function
+cleanup() {
+    echo "Cleaning up..."
+    pkill -f "gsm8k_server.py" 2>/dev/null || true
+    pkill -f "run-api" 2>/dev/null || true
+    pkill -f "vllm_api_server.py" 2>/dev/null || true
+    # Kill by port
+    for port in 8000 9001; do
+        fuser -k ${port}/tcp 2>/dev/null || true
+    done
+}
+trap cleanup EXIT
+
+# Kill any existing processes
+cleanup
+sleep 2
+
+# Start API server
+echo ""
+echo "[1/3] Starting Atropos API..."
+run-api --port 8000 > "$LOGDIR/api.log" 2>&1 &
+API_PID=$!
+sleep 3
+
+# Check API is up
+if ! curl -s "http://localhost:8000/info" > /dev/null 2>&1; then
+    echo "ERROR: API server failed to start"
+    cat "$LOGDIR/api.log"
+    exit 1
+fi
+echo "  ✓ API running (PID: $API_PID)"
+
+# Start trainer (lora_restart manages vLLM internally)
+echo ""
+echo "[2/3] Starting lora_restart trainer..."
+echo "  (This will launch vLLM internally)"
+
+START_TIME=$(date +%s)
+
+CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
+    --model-name "$MODEL" \
+    --weight-bridge-mode lora_restart \
+    --vllm-port 9001 \
+    --atropos-url http://localhost:8000 \
+    --lora-r 16 \
+    --lora-alpha 32 \
+    --training-steps $STEPS \
+    --vllm-restart-interval $RESTART_INTERVAL \
+    --save-path "$LOGDIR/checkpoints" \
+    --benchmark \
+    > "$LOGDIR/trainer.log" 2>&1 &
+TRAINER_PID=$!
+
+# Wait for vLLM to start (trainer launches it)
+echo "  Waiting for trainer to launch vLLM..."
+sleep 30
+
+# Start environment (needs to wait for vLLM)
+echo ""
+echo "[3/3] Starting GSM8K environment..."
+python -u environments/gsm8k_server.py serve \
+    --env.tokenizer_name "$MODEL" \
+    --env.rollout_server_url "http://localhost:8000" \
+    --env.max_token_length 2048 \
+    --env.use_wandb=False \
+    --openai.model_name "$MODEL" \
+    --openai.base_url "http://localhost:9001/v1" \
+    --openai.server_type vllm \
+    --slurm false \
+    > "$LOGDIR/env.log" 2>&1 &
+ENV_PID=$!
+sleep 5
+echo "  ✓ Environment running (PID: $ENV_PID)"
+
+# Wait for trainer to complete
+echo ""
+echo "Waiting for training to complete..."
+echo "(Check progress: tail -f $LOGDIR/trainer.log)"
+
+# Capture the exit code without tripping `set -e` if the trainer fails
+TRAINER_EXIT=0
+wait $TRAINER_PID || TRAINER_EXIT=$?
+
+END_TIME=$(date +%s)
+ELAPSED=$((END_TIME - START_TIME))
+
+echo ""
+echo "=============================================="
+echo "TEST RESULTS"
+echo "=============================================="
+
+if [ $TRAINER_EXIT -eq 0 ]; then
+    echo "✓ Training completed successfully!"
+ echo " Time: ${ELAPSED}s" + echo "" + echo "Checkpoints:" + ls -la "$LOGDIR/checkpoints/" 2>/dev/null || echo " (no checkpoints found)" + echo "" + echo "Benchmark summary:" + grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer.log" 2>/dev/null || echo " (no benchmark found)" +else + echo "✗ Training FAILED (exit code: $TRAINER_EXIT)" + echo "" + echo "Last 50 lines of trainer log:" + tail -50 "$LOGDIR/trainer.log" +fi + +echo "" +echo "==============================================" +echo "Log files saved to: $LOGDIR" +echo "==============================================" diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index 8485403f..0e413c27 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -5,9 +5,11 @@ Contains the four main training modes: - train_legacy: Checkpoint-based training with vLLM restarts - train_shared_vllm: Single-copy mode with CUDA IPC - train_lora: LoRA adapter training with HTTP hot-swap +- train_lora_restart: LoRA training with vLLM restarts (FAST mode) """ import os +import subprocess import time from typing import Optional @@ -658,3 +660,279 @@ def _hotswap_lora_adapter( return False +def train_lora_restart(config: TrainingConfig): + """ + GRPO training with LoRA adapters using vLLM restarts (FAST mode). + + This mode: + 1. Freezes base model, trains only LoRA adapter weights + 2. Runs vLLM WITH CUDA graphs enabled (no --enforce-eager) + 3. Restarts vLLM every N steps with the new adapter pre-loaded + + Performance comparison: + - lora_only (--enforce-eager): ~13 TPS (SLOW) + - lora_restart (CUDA graphs): ~170 TPS (FAST) + + The restart overhead (~45s) is much less than the 12x inference slowdown. + + Requirements: + - No external vLLM needed - this mode manages vLLM internally + - Requires PEFT library for LoRA + """ + if not PEFT_AVAILABLE: + raise RuntimeError( + "PEFT library required for LoRA mode. 
Install with: pip install peft" + ) + + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print("\n" + "=" * 60) + print("LORA RESTART MODE (fast inference with CUDA graphs)") + print("=" * 60) + print(f"Base model: {config.model_name}") + print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") + print(f"vLLM port: {config.vllm_port}") + print(f"Restart interval: every {config.vllm_restart_interval} steps") + print("=" * 60) + print("NOTE: This mode restarts vLLM to keep CUDA graphs enabled.") + print(" Expected inference speed: ~170 TPS (vs ~13 TPS with --enforce-eager)") + print("=" * 60 + "\n") + + # Load model with LoRA adapters for training + print("[1/4] Loading model with LoRA adapters...") + model, tokenizer = load_model_and_tokenizer(config) + + # Only optimize LoRA parameters + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = AdamW(trainable_params, lr=config.lr) + + os.makedirs(config.save_path, exist_ok=True) + + # Save initial adapter + print("[2/4] Saving initial LoRA adapter...") + initial_adapter_path = save_lora_checkpoint(model, config.save_path, 0) + current_adapter_path = initial_adapter_path + + # Launch vLLM with the initial adapter + print("[3/4] Launching vLLM with CUDA graphs (no --enforce-eager)...") + vllm_proc = _launch_vllm_with_lora(config, current_adapter_path) + if vllm_proc is None: + raise RuntimeError("Failed to launch vLLM") + + print(f"[4/4] Starting training for {config.training_steps} steps") + print("-" * 60) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + _terminate_vllm(vllm_proc) + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + "restart_times": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step with proper GRPO + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # Periodic adapter save + vLLM restart + sync_time = 0 + should_sync = (step + 1) % config.vllm_restart_interval == 0 + if should_sync and (step + 1) < config.training_steps: # Don't restart on last step + sync_start = time.time() 
+
+            # Save new adapter
+            current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
+
+            # Restart vLLM with new adapter
+            print("  [RESTART] Restarting vLLM with new adapter...")
+            _terminate_vllm(vllm_proc)
+            vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
+            if vllm_proc is None:
+                raise RuntimeError("Failed to restart vLLM")
+
+            sync_time = time.time() - sync_start
+            benchmark_stats["sync_times"].append(sync_time)
+            benchmark_stats["restart_times"].append(sync_time)
+            print(f"  [RESTART] vLLM restarted in {sync_time:.1f}s")
+
+        # Update metrics
+        metrics.update(
+            {
+                "step_time": step_time,
+                "sync_time": sync_time,
+                "data_fetch_time": data_fetch_time,
+                "gpu_memory_gb": gpu_mem_gb,
+                "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
+            }
+        )
+
+        log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
+
+    # === Cleanup ===
+    print("\nSaving final adapter...")
+    final_sync_start = time.time()
+    final_adapter_path = save_lora_checkpoint(
+        model, config.save_path, config.training_steps, is_final=True
+    )
+    final_sync_time = time.time() - final_sync_start
+    benchmark_stats["sync_times"].append(final_sync_time)
+
+    # Terminate vLLM
+    _terminate_vllm(vllm_proc)
+
+    finalize_training(
+        use_wandb,
+        training_start_time,
+        "lora_restart",
+        config.training_steps,
+        benchmark_stats,
+        config.benchmark,
+    )
+
+    # Save tokenizer
+    tokenizer_path = os.path.join(config.save_path, "tokenizer")
+    tokenizer.save_pretrained(tokenizer_path)
+    print(f"Tokenizer saved to {tokenizer_path}")
+    print(f"Final adapter saved to {final_adapter_path}")
+
+
+def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
+    """
+    Launch vLLM with CUDA graphs enabled, then load a LoRA adapter.
+
+    Unlike lora_only mode, this does NOT use --enforce-eager, so we get
+    full CUDA graph speed (~170 TPS instead of ~13 TPS).
+    """
+    from .vllm_manager import kill_process_on_port, wait_for_vllm_ready
+
+    # Kill any existing process on the port
+    kill_process_on_port(config.vllm_port)
+
+    # Find the vllm_api_server.py script
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    server_script = os.path.join(script_dir, "vllm_api_server.py")
+
+    # Build command - NO --enforce-eager for full speed
+    cmd = [
+        "python", server_script,
+        "--model", config.model_name,
+        "--port", str(config.vllm_port),
+        "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
+        "--enable-lora",
+        "--max-lora-rank", str(max(config.lora_r * 2, 32)),
+        # Note: NOT adding --enforce-eager - this is the key difference!
+        # The adapter is loaded via the /lora/load endpoint once the server is up.
+    ]
+
+    # Set environment for GPU selection
+    env = os.environ.copy()
+    if config.vllm_gpu is not None:
+        env["CUDA_VISIBLE_DEVICES"] = str(config.vllm_gpu)
+        print(f"  GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
+    else:
+        print("  GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")
+
+    print(f"  Launching: {' '.join(cmd)}")
+    print(f"  Adapter: {adapter_path}")
+
+    try:
+        proc = subprocess.Popen(cmd, env=env)
+        print(f"  vLLM PID: {proc.pid}")
+
+        # Wait for server to be ready
+        if not wait_for_vllm_ready(config.vllm_port, timeout=180):
+            print("  ERROR: vLLM failed to start")
+            proc.terminate()
+            return None
+
+        # Load the LoRA adapter
+        print("  Loading LoRA adapter...")
+        try:
+            resp = requests.post(
+                f"http://localhost:{config.vllm_port}/lora/load",
+                json={"adapter_path": adapter_path, "adapter_name": "training_adapter"},
+                timeout=60,
+            )
+            if resp.status_code == 200:
+                print("  ✓ Adapter loaded successfully")
+            else:
+                print(f"  WARNING: Adapter load returned {resp.status_code}: {resp.text}")
+        except Exception as e:
+            print(f"  WARNING: Could not load adapter: {e}")
+            # Continue anyway - base model inference still works
+
+        return proc
+
+    except Exception as e:
+        print(f"  ERROR: {e}")
+        return None
+
+
+def _terminate_vllm(proc: Optional[subprocess.Popen]) -> None:
+    """Terminate a vLLM process."""
+    if proc is None:
+        return
+
+    try:
+        proc.terminate()
+        proc.wait(timeout=10)
+    except subprocess.TimeoutExpired:
+        proc.kill()
+        proc.wait()
+    except Exception:
+        pass
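Note on verifying the adapter-swap path: the `/lora/load` endpoint that `_launch_vllm_with_lora` POSTs to can be smoke-tested by hand. A minimal sketch, assuming a vLLM API server already running on port 9001 and an adapter directory written by `save_lora_checkpoint` (the path below is a placeholder):

    # POST the same JSON payload _launch_vllm_with_lora sends
    curl -s -X POST "http://localhost:9001/lora/load" \
        -H "Content-Type: application/json" \
        -d '{"adapter_path": "/path/to/adapter", "adapter_name": "training_adapter"}'

A 200 response corresponds to the success branch above; any other status is logged there as a warning, and inference falls back to the base model.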