mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
lora restart saving gradient changes
This commit is contained in:
parent
1127083b5f
commit
90281f5993
7 changed files with 805 additions and 19 deletions
|
|
@ -58,21 +58,35 @@ Data Flow:
|
|||
|
||||
---
|
||||
|
||||
## Three Training Modes
|
||||
## Four Training Modes
|
||||
|
||||
| Mode | Description | Memory | Best For |
|
||||
|------|-------------|--------|----------|
|
||||
| **shared_vllm** | Single-copy via CUDA IPC | 1x model | Same GPU, maximum efficiency |
|
||||
| **lora_only** | Train adapters, HTTP hot-swap | 1x + small adapter | Simple setup, debugging |
|
||||
| **legacy** | Full model, restart vLLM | 2x model | Different GPUs, simple setup |
|
||||
| Mode | Description | Memory | Inference Speed | Best For |
|
||||
|------|-------------|--------|-----------------|----------|
|
||||
| **shared_vllm** | Single-copy via CUDA IPC | 1x model | ~170 TPS | Same GPU, maximum efficiency |
|
||||
| **lora_restart** | LoRA + vLLM restarts | 1x + adapter | ~170 TPS | LoRA training with speed |
|
||||
| **lora_only** | LoRA + HTTP hot-swap | 1x + adapter | ~13 TPS ⚠️ | Debugging only |
|
||||
| **legacy** | Full model, restart vLLM | 2x model | ~170 TPS | Different GPUs, simple setup |
|
||||
|
||||
### ⚠️ IMPORTANT: `lora_only` Performance Warning
|
||||
|
||||
The `lora_only` mode requires `--enforce-eager` which **disables CUDA graphs**, resulting in:
|
||||
- **12x slower inference** (~13 TPS vs ~170 TPS)
|
||||
- Training that takes **4x longer** (401 min vs 132 min for 120 steps)
|
||||
|
||||
**Use `lora_restart` instead** - it restarts vLLM to keep CUDA graphs enabled.
|
||||
|
||||
### Recommendation
|
||||
|
||||
**Start with `lora_only`** - it's the easiest to set up and debug.
|
||||
**Use `shared_vllm`** for production training when:
|
||||
- You have enough GPU memory for the full model
|
||||
- You want fastest training (no overhead)
|
||||
|
||||
**Use `shared_vllm`** for production training when you need:
|
||||
- Fastest weight synchronization (CUDA IPC, zero-copy updates)
|
||||
- True on-policy training (vLLM sees updates immediately)
|
||||
**Use `lora_restart`** when:
|
||||
- You want LoRA's memory efficiency
|
||||
- You want fast inference (~170 TPS with CUDA graphs)
|
||||
- You can tolerate ~45s restart overhead every N steps
|
||||
|
||||
**Avoid `lora_only`** unless you're debugging - the 12x inference penalty is severe.
|
||||
|
||||
**Use `shared_vllm`** for single-GPU training when you need maximum efficiency.
|
||||
|
||||
|
|
|
|||
|
|
@ -121,6 +121,12 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=9001,
|
||||
help="Port for the vLLM server",
|
||||
)
|
||||
group.add_argument(
|
||||
"--vllm-gpu",
|
||||
type=int,
|
||||
default=None,
|
||||
help="GPU ID for vLLM server. If not set, uses same GPU as trainer.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
"--vllm-gpu-memory-utilization",
|
||||
|
|
@ -146,7 +152,7 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
"--vllm-restart-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Restart vLLM every N training steps (legacy mode only)",
|
||||
help="Restart vLLM every N training steps (legacy/lora_restart modes)",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -189,9 +195,12 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
|
|||
group.add_argument(
|
||||
"--weight-bridge-mode",
|
||||
type=str,
|
||||
choices=["shared_vllm", "lora_only", "none"],
|
||||
choices=["shared_vllm", "lora_only", "lora_restart", "none"],
|
||||
default="none",
|
||||
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
|
||||
help=(
|
||||
"Weight sync mode: 'shared_vllm' (CUDA IPC), 'lora_only' (slow, --enforce-eager), "
|
||||
"'lora_restart' (fast, restarts vLLM), or 'none' (legacy)"
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--vllm-config-path",
|
||||
|
|
@ -348,6 +357,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
# vLLM settings
|
||||
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu=getattr(args, "vllm_gpu", None),
|
||||
vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
|
||||
max_model_len=getattr(args, "max_model_len", 4096),
|
||||
dtype=getattr(args, "dtype", "bfloat16"),
|
||||
|
|
|
|||
|
|
@ -87,6 +87,13 @@ class TrainingConfig(BaseModel):
|
|||
3, description="Restart vLLM every N training steps (legacy mode)"
|
||||
)
|
||||
vllm_port: int = Field(9001, description="Port for the vLLM server")
|
||||
vllm_gpu: Optional[int] = Field(
|
||||
None,
|
||||
description=(
|
||||
"GPU ID for vLLM server (lora_restart/legacy modes). "
|
||||
"If None, uses same GPU as trainer. Set different for parallelism."
|
||||
),
|
||||
)
|
||||
vllm_gpu_memory_utilization: float = Field(
|
||||
0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
|
||||
)
|
||||
|
|
@ -105,12 +112,13 @@ class TrainingConfig(BaseModel):
|
|||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
# === Training Mode Configuration ===
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
|
||||
"none",
|
||||
description=(
|
||||
"How to synchronize weights with inference server. "
|
||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
||||
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@
|
|||
"""
|
||||
GRPO (Group Relative Policy Optimization) Trainer.
|
||||
|
||||
Supports three training modes:
|
||||
Supports four training modes:
|
||||
- none (legacy): Periodic checkpoint saves + vLLM restarts
|
||||
- shared_vllm: Single-copy mode with CUDA IPC weight sharing
|
||||
- lora_only: LoRA adapter training with HTTP hot-swap
|
||||
- lora_only: LoRA adapter training with HTTP hot-swap (SLOW - needs --enforce-eager)
|
||||
- lora_restart: LoRA training with vLLM restarts (FAST - CUDA graphs enabled)
|
||||
|
||||
Usage:
|
||||
# Legacy mode (manages vLLM internally)
|
||||
|
|
@ -15,13 +16,18 @@ Usage:
|
|||
python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
|
||||
--weight-bridge-mode shared_vllm
|
||||
|
||||
# LoRA mode (requires external vLLM with --enable-lora --enforce-eager)
|
||||
# LoRA mode with HTTP hot-swap (SLOW - 13 TPS due to --enforce-eager)
|
||||
python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
|
||||
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
|
||||
|
||||
# LoRA mode with vLLM restarts (FAST - 170 TPS with CUDA graphs)
|
||||
python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
|
||||
--weight-bridge-mode lora_restart --lora-r 16 --lora-alpha 32 \\
|
||||
--vllm-restart-interval 3
|
||||
"""
|
||||
|
||||
from .cli import config_from_args, parse_args
|
||||
from .trainers import train_legacy, train_lora, train_shared_vllm
|
||||
from .trainers import train_legacy, train_lora, train_lora_restart, train_shared_vllm
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -44,8 +50,14 @@ def main():
|
|||
|
||||
elif config.weight_bridge_mode == "lora_only":
|
||||
# LoRA mode: freeze base model, train adapters only (HTTP hot-swap)
|
||||
# WARNING: This is SLOW (~13 TPS) because it requires --enforce-eager
|
||||
train_lora(config)
|
||||
|
||||
elif config.weight_bridge_mode == "lora_restart":
|
||||
# LoRA mode with vLLM restarts (FAST - uses CUDA graphs)
|
||||
# Restarts vLLM every vllm_restart_interval steps with new adapter
|
||||
train_lora_restart(config)
|
||||
|
||||
else:
|
||||
# Legacy mode: periodic checkpoint saves + vLLM restarts
|
||||
train_legacy(config)
|
||||
|
|
|
|||
326
example_trainer/scripts/compare_lora_modes.sh
Executable file
326
example_trainer/scripts/compare_lora_modes.sh
Executable file
|
|
@ -0,0 +1,326 @@
|
|||
#!/bin/bash
# ============================================================================
# Compare lora_restart vs lora_only performance
# ============================================================================
# Runs both modes in parallel with separate APIs/environments/ports
# All commands run in background (single terminal)
# Results uploaded to W&B
#
# lora_restart (GPU 0): trainer launches its own vLLM (port 9001) and
#   restarts it with a fresh adapter every RESTART_INTERVAL steps.
# lora_only (GPU 1): vLLM is started here externally with
#   --enable-lora --enforce-eager (port 9002); the trainer hot-swaps adapters.
#
# Usage:
#   ./compare_lora_modes.sh [steps]
#   ./compare_lora_modes.sh 30   # 30 steps (default)
#   ./compare_lora_modes.sh 10   # Quick 10-step test
# ============================================================================

set -e

# Configuration
MODEL="Qwen/Qwen3-4B-Instruct-2507"
STEPS="${1:-30}"
RESTART_INTERVAL=3
WANDB_PROJECT="lora-mode-comparison"

# Port allocation
# lora_restart: API 8001, vLLM 9001
# lora_only: API 8002, vLLM 9002

echo "============================================================================"
echo "LoRA Mode Comparison: lora_restart vs lora_only"
echo "============================================================================"
echo "Model: $MODEL"
echo "Steps: $STEPS"
echo "Restart interval: $RESTART_INTERVAL"
echo "W&B project: $WANDB_PROJECT"
echo ""
echo "Port allocation:"
echo " lora_restart: API=8001, vLLM=9001, GPU=0"
echo " lora_only: API=8002, vLLM=9002, GPU=1"
echo "============================================================================"

# Get script directory and repo root
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
cd "$REPO_ROOT"
echo "Working directory: $(pwd)"

# Create log directory (timestamped so runs never collide)
LOGDIR="./lora_comparison_$(date +%Y%m%d_%H%M%S)"
mkdir -p "$LOGDIR"
echo "Log directory: $LOGDIR"

# Cleanup function — kills every process either mode may have started,
# first by process name, then by port as a fallback.
cleanup() {
    echo ""
    echo "Cleaning up all processes..."

    # Kill by name
    pkill -f "gsm8k_server.py" 2>/dev/null || true
    pkill -f "run-api" 2>/dev/null || true
    pkill -f "vllm_api_server.py" 2>/dev/null || true
    pkill -f "example_trainer.grpo" 2>/dev/null || true

    # Kill by port
    for port in 8001 8002 9001 9002; do
        fuser -k ${port}/tcp 2>/dev/null || true
    done

    echo "Cleanup complete."
}
# Run cleanup on any exit (success, failure, or Ctrl-C).
trap cleanup EXIT

# Initial cleanup so stale processes from a previous run don't hold the ports
echo ""
echo "Killing any existing processes on ports 8001, 8002, 9001, 9002..."
cleanup
sleep 3

# ============================================================================
# MODE 1: lora_restart (GPU 0, ports 8001/9001)
# ============================================================================
echo ""
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "[1/2] LORA_RESTART MODE (GPU 0, API:8001, vLLM:9001)"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

# Start API for lora_restart
echo " Starting API server (port 8001)..."
run-api --port 8001 > "$LOGDIR/api_restart.log" 2>&1 &
RESTART_API_PID=$!
sleep 3

# Check API is up
if curl -s "http://localhost:8001/info" > /dev/null 2>&1; then
    echo " ✓ API running (PID: $RESTART_API_PID)"
else
    echo " ✗ API failed to start"
    cat "$LOGDIR/api_restart.log"
    exit 1
fi

# Start trainer (lora_restart manages vLLM internally — no external vLLM here)
echo " Starting lora_restart trainer (will launch vLLM on port 9001)..."
CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
    --model-name "$MODEL" \
    --weight-bridge-mode lora_restart \
    --vllm-port 9001 \
    --atropos-url http://localhost:8001 \
    --lora-r 16 \
    --lora-alpha 32 \
    --training-steps $STEPS \
    --vllm-restart-interval $RESTART_INTERVAL \
    --save-path "$LOGDIR/checkpoints_restart" \
    --use-wandb \
    --wandb-project "$WANDB_PROJECT" \
    --wandb-group "comparison-$(date +%Y%m%d)" \
    --benchmark \
    > "$LOGDIR/trainer_restart.log" 2>&1 &
RESTART_TRAINER_PID=$!
echo " ✓ Trainer started (PID: $RESTART_TRAINER_PID)"

# Wait for vLLM to be ready (trainer launches it); poll up to ~120s
echo " Waiting for vLLM to start (port 9001)..."
for i in {1..60}; do
    if curl -s "http://localhost:9001/health" > /dev/null 2>&1; then
        echo " ✓ vLLM ready after ~${i}s"
        break
    fi
    sleep 2
done

# Start environment for lora_restart
echo " Starting environment..."
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name "$MODEL" \
    --env.rollout_server_url "http://localhost:8001" \
    --env.max_token_length 2048 \
    --env.use_wandb=True \
    --env.wandb_name "lora-restart-env" \
    --openai.model_name "$MODEL" \
    --openai.base_url "http://localhost:9001/v1" \
    --openai.server_type vllm \
    --slurm false \
    > "$LOGDIR/env_restart.log" 2>&1 &
RESTART_ENV_PID=$!
echo " ✓ Environment started (PID: $RESTART_ENV_PID)"

# ============================================================================
# MODE 2: lora_only (GPU 1, ports 8002/9002)
# ============================================================================
echo ""
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "[2/2] LORA_ONLY MODE (GPU 1, API:8002, vLLM:9002)"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

# Start API for lora_only
echo " Starting API server (port 8002)..."
run-api --port 8002 > "$LOGDIR/api_only.log" 2>&1 &
ONLY_API_PID=$!
sleep 3

# Check API is up
if curl -s "http://localhost:8002/info" > /dev/null 2>&1; then
    echo " ✓ API running (PID: $ONLY_API_PID)"
else
    echo " ✗ API failed to start"
    cat "$LOGDIR/api_only.log"
    exit 1
fi

# Start vLLM for lora_only (external, with --enforce-eager — this is what
# makes this mode slow: CUDA graphs are disabled)
echo " Starting vLLM with --enable-lora --enforce-eager (port 9002)..."
CUDA_VISIBLE_DEVICES=1 python example_trainer/vllm_api_server.py \
    --model "$MODEL" \
    --port 9002 \
    --gpu-memory-utilization 0.45 \
    --enable-lora \
    --max-lora-rank 32 \
    --enforce-eager \
    > "$LOGDIR/vllm_only.log" 2>&1 &
ONLY_VLLM_PID=$!
echo " ✓ vLLM started (PID: $ONLY_VLLM_PID)"

# Wait for vLLM to be ready; poll up to ~180s (cold model load)
echo " Waiting for vLLM to start (port 9002)..."
for i in {1..90}; do
    if curl -s "http://localhost:9002/health" > /dev/null 2>&1; then
        echo " ✓ vLLM ready after ~${i}s"
        break
    fi
    sleep 2
done

# Start environment for lora_only
echo " Starting environment..."
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name "$MODEL" \
    --env.rollout_server_url "http://localhost:8002" \
    --env.max_token_length 2048 \
    --env.use_wandb=True \
    --env.wandb_name "lora-only-env" \
    --openai.model_name "$MODEL" \
    --openai.base_url "http://localhost:9002/v1" \
    --openai.server_type vllm \
    --slurm false \
    > "$LOGDIR/env_only.log" 2>&1 &
ONLY_ENV_PID=$!
echo " ✓ Environment started (PID: $ONLY_ENV_PID)"

# Start trainer for lora_only (no --vllm-restart-interval: it hot-swaps)
echo " Starting lora_only trainer..."
CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \
    --model-name "$MODEL" \
    --weight-bridge-mode lora_only \
    --vllm-port 9002 \
    --atropos-url http://localhost:8002 \
    --lora-r 16 \
    --lora-alpha 32 \
    --training-steps $STEPS \
    --save-path "$LOGDIR/checkpoints_only" \
    --use-wandb \
    --wandb-project "$WANDB_PROJECT" \
    --wandb-group "comparison-$(date +%Y%m%d)" \
    --benchmark \
    > "$LOGDIR/trainer_only.log" 2>&1 &
ONLY_TRAINER_PID=$!
echo " ✓ Trainer started (PID: $ONLY_TRAINER_PID)"

# ============================================================================
# Save PIDs and monitor
# ============================================================================
cat > "$LOGDIR/pids.txt" << EOF
RESTART_API_PID=$RESTART_API_PID
RESTART_TRAINER_PID=$RESTART_TRAINER_PID
RESTART_ENV_PID=$RESTART_ENV_PID
ONLY_API_PID=$ONLY_API_PID
ONLY_VLLM_PID=$ONLY_VLLM_PID
ONLY_ENV_PID=$ONLY_ENV_PID
ONLY_TRAINER_PID=$ONLY_TRAINER_PID
EOF

echo ""
echo "============================================================================"
echo "All components started!"
echo "============================================================================"
echo ""
echo "📊 Monitor progress:"
echo " tail -f $LOGDIR/trainer_restart.log # lora_restart"
echo " tail -f $LOGDIR/trainer_only.log # lora_only"
echo ""
echo "🔍 Watch both:"
echo " tail -f $LOGDIR/trainer_*.log"
echo ""
echo "📈 W&B Dashboard:"
echo " https://wandb.ai/$WANDB_PROJECT"
echo ""
echo "Waiting for trainers to complete..."
echo "(lora_restart should finish MUCH faster than lora_only)"
echo ""

# Poll both trainers every 30s until neither is running. `kill -0` tests
# liveness; `wait` then reaps the process and yields its exit status.
# (The && / || lists keep `set -e` from aborting on a failed trainer.)
RESTART_STATUS="running"
ONLY_STATUS="running"

while [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; do
    sleep 30

    # Check lora_restart
    if [ "$RESTART_STATUS" = "running" ]; then
        if ! kill -0 $RESTART_TRAINER_PID 2>/dev/null; then
            wait $RESTART_TRAINER_PID 2>/dev/null && RESTART_STATUS="completed" || RESTART_STATUS="failed"
            echo " lora_restart: $RESTART_STATUS"
        fi
    fi

    # Check lora_only
    if [ "$ONLY_STATUS" = "running" ]; then
        if ! kill -0 $ONLY_TRAINER_PID 2>/dev/null; then
            wait $ONLY_TRAINER_PID 2>/dev/null && ONLY_STATUS="completed" || ONLY_STATUS="failed"
            echo " lora_only: $ONLY_STATUS"
        fi
    fi

    # Show status
    if [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; then
        echo " [$(date +%H:%M:%S)] lora_restart: $RESTART_STATUS, lora_only: $ONLY_STATUS"
    fi
done

# ============================================================================
# Print results
# ============================================================================
echo ""
echo "============================================================================"
echo "COMPARISON RESULTS"
echo "============================================================================"

echo ""
echo "📊 LORA_RESTART (CUDA graphs, vLLM restarts):"
echo "─────────────────────────────────────────────────"
grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_restart.log" 2>/dev/null || echo " (check $LOGDIR/trainer_restart.log)"

echo ""
echo "📊 LORA_ONLY (--enforce-eager, hot-swap):"
echo "─────────────────────────────────────────────────"
grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_only.log" 2>/dev/null || echo " (check $LOGDIR/trainer_only.log)"

echo ""
echo "============================================================================"
echo "📁 LOGS SAVED TO: $LOGDIR"
echo "============================================================================"
echo ""
echo "Log files:"
echo " $LOGDIR/trainer_restart.log # lora_restart trainer"
echo " $LOGDIR/trainer_only.log # lora_only trainer"
echo " $LOGDIR/vllm_only.log # lora_only vLLM"
echo " $LOGDIR/env_restart.log # lora_restart environment"
echo " $LOGDIR/env_only.log # lora_only environment"
echo ""
echo "Checkpoints:"
echo " $LOGDIR/checkpoints_restart/"
echo " $LOGDIR/checkpoints_only/"
echo ""
echo "W&B runs should be visible at:"
echo " https://wandb.ai/$WANDB_PROJECT"
echo ""
echo "============================================================================"
echo "Done!"
||||
138
example_trainer/scripts/test_lora_restart.sh
Executable file
138
example_trainer/scripts/test_lora_restart.sh
Executable file
|
|
@ -0,0 +1,138 @@
|
|||
#!/bin/bash
# Quick test script for lora_restart mode
# Tests that the mode works and compares timing
#
# Launches the Atropos API (port 8000), the lora_restart trainer (which
# starts its own vLLM on port 9001), and a GSM8K environment, then waits
# for the trainer and reports success/failure plus elapsed time.

set -e

MODEL="Qwen/Qwen3-4B-Instruct-2507"
STEPS=10
RESTART_INTERVAL=3

echo "=============================================="
echo "Testing lora_restart mode"
echo "=============================================="
echo "Model: $MODEL"
echo "Steps: $STEPS"
echo "Restart interval: $RESTART_INTERVAL"
echo "=============================================="

# Get script directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
cd "$REPO_ROOT"

# Create log directory
LOGDIR="./lora_restart_test_$(date +%Y%m%d_%H%M%S)"
mkdir -p "$LOGDIR"
echo "Logs: $LOGDIR"

# Cleanup function — kill by process name, then by port as a fallback
cleanup() {
    echo "Cleaning up..."
    pkill -f "gsm8k_server.py" 2>/dev/null || true
    pkill -f "run-api" 2>/dev/null || true
    pkill -f "vllm_api_server.py" 2>/dev/null || true
    # Kill by port
    for port in 8000 9001; do
        fuser -k ${port}/tcp 2>/dev/null || true
    done
}
trap cleanup EXIT

# Kill any existing processes
cleanup
sleep 2

# Start API server
echo ""
echo "[1/3] Starting Atropos API..."
run-api --port 8000 > "$LOGDIR/api.log" 2>&1 &
API_PID=$!
sleep 3

# Check API is up
if ! curl -s "http://localhost:8000/info" > /dev/null 2>&1; then
    echo "ERROR: API server failed to start"
    cat "$LOGDIR/api.log"
    exit 1
fi
echo " ✓ API running (PID: $API_PID)"

# Start trainer (lora_restart manages vLLM internally)
echo ""
echo "[2/3] Starting lora_restart trainer..."
echo " (This will launch vLLM internally)"

START_TIME=$(date +%s)

CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
    --model-name "$MODEL" \
    --weight-bridge-mode lora_restart \
    --vllm-port 9001 \
    --atropos-url http://localhost:8000 \
    --lora-r 16 \
    --lora-alpha 32 \
    --training-steps $STEPS \
    --vllm-restart-interval $RESTART_INTERVAL \
    --save-path "$LOGDIR/checkpoints" \
    --benchmark \
    > "$LOGDIR/trainer.log" 2>&1 &
TRAINER_PID=$!

# Wait for vLLM to start (trainer launches it)
echo " Waiting for trainer to launch vLLM..."
sleep 30

# Start environment (needs to wait for vLLM)
echo ""
echo "[3/3] Starting GSM8K environment..."
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name "$MODEL" \
    --env.rollout_server_url "http://localhost:8000" \
    --env.max_token_length 2048 \
    --env.use_wandb=False \
    --openai.model_name "$MODEL" \
    --openai.base_url "http://localhost:9001/v1" \
    --openai.server_type vllm \
    --slurm false \
    > "$LOGDIR/env.log" 2>&1 &
ENV_PID=$!
sleep 5
echo " ✓ Environment running (PID: $ENV_PID)"

# Wait for trainer to complete
echo ""
echo "Waiting for training to complete..."
echo "(Check progress: tail -f $LOGDIR/trainer.log)"

# FIX: under `set -e`, a bare `wait $TRAINER_PID` aborts the whole script
# when the trainer exits non-zero, so the TEST RESULTS section (and the
# failure log dump) below would never run. Capture the exit code without
# tripping errexit.
TRAINER_EXIT=0
wait $TRAINER_PID || TRAINER_EXIT=$?

END_TIME=$(date +%s)
ELAPSED=$((END_TIME - START_TIME))

echo ""
echo "=============================================="
echo "TEST RESULTS"
echo "=============================================="

if [ $TRAINER_EXIT -eq 0 ]; then
    echo "✓ Training completed successfully!"
    echo " Time: ${ELAPSED}s"
    echo ""
    echo "Checkpoints:"
    ls -la "$LOGDIR/checkpoints/" 2>/dev/null || echo " (no checkpoints found)"
    echo ""
    echo "Benchmark summary:"
    grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer.log" 2>/dev/null || echo " (no benchmark found)"
else
    echo "✗ Training FAILED (exit code: $TRAINER_EXIT)"
    echo ""
    echo "Last 50 lines of trainer log:"
    tail -50 "$LOGDIR/trainer.log"
fi

echo ""
echo "=============================================="
echo "Log files saved to: $LOGDIR"
echo "=============================================="
|
||||
|
|
@ -5,9 +5,11 @@ Contains the four main training modes:
|
|||
- train_legacy: Checkpoint-based training with vLLM restarts
|
||||
- train_shared_vllm: Single-copy mode with CUDA IPC
|
||||
- train_lora: LoRA adapter training with HTTP hot-swap
|
||||
- train_lora_restart: LoRA training with vLLM restarts (FAST mode)
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -658,3 +660,279 @@ def _hotswap_lora_adapter(
|
|||
return False
|
||||
|
||||
|
||||
def train_lora_restart(config: TrainingConfig):
    """
    GRPO training with LoRA adapters using vLLM restarts (FAST mode).

    This mode:
      1. Freezes the base model, trains only LoRA adapter weights.
      2. Runs vLLM WITH CUDA graphs enabled (no --enforce-eager).
      3. Restarts vLLM every ``config.vllm_restart_interval`` steps with the
         newly saved adapter pre-loaded.

    Performance comparison:
      - lora_only (--enforce-eager): ~13 TPS (SLOW)
      - lora_restart (CUDA graphs):  ~170 TPS (FAST)

    The restart overhead (~45s) is much less than the 12x inference slowdown.

    Args:
        config: Full training configuration (model, LoRA, vLLM, logging).

    Raises:
        RuntimeError: If PEFT is not installed, vLLM fails to (re)start, or
            the Atropos API is unreachable.

    Requirements:
        - No external vLLM needed - this mode manages vLLM internally.
        - Requires PEFT library for LoRA.
    """
    if not PEFT_AVAILABLE:
        raise RuntimeError(
            "PEFT library required for LoRA mode. Install with: pip install peft"
        )

    training_start_time = time.time()

    # === Setup ===
    use_wandb = setup_wandb(config)

    banner = "=" * 60
    print("\n" + banner)
    print("LORA RESTART MODE (fast inference with CUDA graphs)")
    print(banner)
    print(f"Base model: {config.model_name}")
    print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Save path: {config.save_path}")
    print(f"vLLM port: {config.vllm_port}")
    print(f"Restart interval: every {config.vllm_restart_interval} steps")
    print(banner)
    print("NOTE: This mode restarts vLLM to keep CUDA graphs enabled.")
    print("      Expected inference speed: ~170 TPS (vs ~13 TPS with --enforce-eager)")
    print(banner + "\n")

    # Load model with LoRA adapters for training
    print("[1/4] Loading model with LoRA adapters...")
    model, tokenizer = load_model_and_tokenizer(config)

    # Only LoRA parameters require grad after PEFT wrapping; optimize just those.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config.lr)

    os.makedirs(config.save_path, exist_ok=True)

    # Save the (untrained) initial adapter so vLLM can start with it loaded.
    print("[2/4] Saving initial LoRA adapter...")
    current_adapter_path = save_lora_checkpoint(model, config.save_path, 0)

    # Launch vLLM with the initial adapter (CUDA graphs stay enabled).
    print("[3/4] Launching vLLM with CUDA graphs (no --enforce-eager)...")
    vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
    if vllm_proc is None:
        raise RuntimeError("Failed to launch vLLM")

    print(f"[4/4] Starting training for {config.training_steps} steps")
    print("-" * 60)

    # From here on, vLLM must always be torn down - even if a training step,
    # data fetch, or restart raises. The original flow only terminated vLLM
    # on the happy path (and on the API-check failure), leaking the
    # subprocess on any mid-loop exception.
    try:
        # Check Atropos API before doing any training work.
        if not check_atropos_api(url=config.atropos_url, timeout=30):
            raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}")
        register_trainer(config)

        # === Benchmark tracking ===
        benchmark_stats = {
            "step_times": [],
            "sync_times": [],
            "data_fetch_times": [],
            "gpu_memories": [],
            "restart_times": [],
        }

        # Hoisted out of the loop: availability cannot change mid-run.
        cuda_available = torch.cuda.is_available()

        # === Training Loop ===
        batches = []
        for step in range(config.training_steps):
            print(f"\nStep {step + 1}/{config.training_steps}")

            # Fetch data (with inference logprobs for proper GRPO).
            data_fetch_start = time.time()
            if not batches:
                batches, _ = get_data(
                    config.batch_size,
                    config.seq_len,
                    config.atropos_url,
                    extract_inference_logprobs=True,
                )
            batch_data = batches.pop(0)
            token_batches, label_batches, advantage_batches, temperature_batches = (
                batch_data[:4]
            )
            # Older data sources may not include inference logprobs.
            inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
            data_fetch_time = time.time() - data_fetch_start
            benchmark_stats["data_fetch_times"].append(data_fetch_time)

            # Training step with proper GRPO.
            step_start = time.time()
            metrics = run_training_step(
                model,
                optimizer,
                token_batches,
                label_batches,
                advantage_batches,
                temperature_batches,
                config,
                inference_logprob_batches=inference_logprob_batches,
            )
            step_time = time.time() - step_start
            benchmark_stats["step_times"].append(step_time)

            # GPU memory tracking (0 when CUDA is unavailable).
            gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if cuda_available else 0
            gpu_mem_reserved_gb = (
                torch.cuda.memory_reserved() / 1e9 if cuda_available else 0
            )
            benchmark_stats["gpu_memories"].append(gpu_mem_gb)

            # Periodic adapter save + vLLM restart. Skip on the last step:
            # the final adapter is saved after the loop anyway.
            sync_time = 0
            should_sync = (step + 1) % config.vllm_restart_interval == 0
            if should_sync and (step + 1) < config.training_steps:
                sync_start = time.time()

                # Save new adapter, then restart vLLM with it pre-loaded.
                current_adapter_path = save_lora_checkpoint(
                    model, config.save_path, step + 1
                )
                print(" [RESTART] Restarting vLLM with new adapter...")
                _terminate_vllm(vllm_proc)
                vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
                if vllm_proc is None:
                    raise RuntimeError("Failed to restart vLLM")

                sync_time = time.time() - sync_start
                benchmark_stats["sync_times"].append(sync_time)
                benchmark_stats["restart_times"].append(sync_time)
                print(f" [RESTART] vLLM restarted in {sync_time:.1f}s")

            # Merge per-step timing/memory into the step's training metrics.
            metrics.update(
                {
                    "step_time": step_time,
                    "sync_time": sync_time,
                    "data_fetch_time": data_fetch_time,
                    "gpu_memory_gb": gpu_mem_gb,
                    "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
                }
            )

            log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)

        # === Final adapter save (still inside try so vLLM is reaped on error) ===
        print("\nSaving final adapter...")
        final_sync_start = time.time()
        final_adapter_path = save_lora_checkpoint(
            model, config.save_path, config.training_steps, is_final=True
        )
        benchmark_stats["sync_times"].append(time.time() - final_sync_start)
    finally:
        # Terminate vLLM on both success and failure. A failed restart leaves
        # vllm_proc as None just before the raise, hence the guard.
        if vllm_proc is not None:
            _terminate_vllm(vllm_proc)

    finalize_training(
        use_wandb,
        training_start_time,
        "lora_restart",
        config.training_steps,
        benchmark_stats,
        config.benchmark,
    )

    # Save tokenizer alongside the adapters so the run is self-contained.
    tokenizer_path = os.path.join(config.save_path, "tokenizer")
    tokenizer.save_pretrained(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")
    print(f"Final adapter saved to {final_adapter_path}")
||||
|
||||
|
||||
def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
    """
    Launch vLLM and load a LoRA adapter via HTTP (CUDA graphs enabled).

    Unlike lora_only mode, this does NOT use --enforce-eager, so we get
    full CUDA graph speed (~170 TPS instead of ~13 TPS). The adapter is
    loaded through the server's /lora/load endpoint once it is ready.

    Args:
        config: Training configuration (model name, vLLM port, GPU
            selection, GPU memory utilization, LoRA rank).
        adapter_path: Filesystem path of the saved LoRA adapter to load.

    Returns:
        The running vLLM server process, or None if startup failed.
    """
    import sys  # local import, matching this function's import style

    from .vllm_manager import kill_process_on_port, wait_for_vllm_ready

    # Kill any existing process on the port
    kill_process_on_port(config.vllm_port)

    # Find the vllm_api_server.py script next to this module
    script_dir = os.path.dirname(os.path.abspath(__file__))
    server_script = os.path.join(script_dir, "vllm_api_server.py")

    # Build command - NO --enforce-eager for full speed.
    # Use sys.executable so the server runs under the same interpreter /
    # virtualenv as the trainer ("python" on PATH may be a different env).
    cmd = [
        sys.executable, server_script,
        "--model", config.model_name,
        "--port", str(config.vllm_port),
        "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
        "--enable-lora",
        "--max-lora-rank", str(max(config.lora_r * 2, 32)),
        # Note: NOT adding --enforce-eager - this is the key difference!
        # LoRA adapter is loaded after startup, CUDA graphs compiled with it
    ]

    # Set environment for GPU selection
    env = os.environ.copy()
    if config.vllm_gpu is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(config.vllm_gpu)
        print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
    else:
        print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")

    print(f" Launching: {' '.join(cmd)}")
    print(f" Adapter: {adapter_path}")

    try:
        proc = subprocess.Popen(cmd, env=env)
        print(f" vLLM PID: {proc.pid}")

        # Wait for server to be ready
        if not wait_for_vllm_ready(config.vllm_port, timeout=180):
            print(" ERROR: vLLM failed to start")
            # Reap the child so it does not linger as a zombie; escalate
            # to SIGKILL if it ignores SIGTERM within the grace period.
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
            return None

        # Load the LoRA adapter through the server's HTTP API
        print(" Loading LoRA adapter...")
        try:
            resp = requests.post(
                f"http://localhost:{config.vllm_port}/lora/load",
                json={"adapter_path": adapter_path, "adapter_name": "training_adapter"},
                timeout=60,
            )
            if resp.status_code == 200:
                print(" ✓ Adapter loaded successfully")
            else:
                print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
        except Exception as e:
            print(f" WARNING: Could not load adapter: {e}")
            # Continue anyway - base model inference still works

        return proc

    except Exception as e:
        # Broad catch is deliberate: any launch failure is reported to the
        # caller as None so it can decide whether to abort training.
        print(f" ERROR: {e}")
        return None
|
||||
|
||||
|
||||
def _terminate_vllm(proc: Optional[subprocess.Popen]) -> None:
|
||||
"""Terminate a vLLM process."""
|
||||
if proc is None:
|
||||
return
|
||||
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue