mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
enforce eager check
This commit is contained in:
parent
84cee536a3
commit
211f91b528
4 changed files with 190 additions and 616 deletions
|
|
@ -1,12 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Benchmark LoRA vs Shared vLLM inference performance.
|
||||
Benchmark LoRA inference modes to find the fastest approach.
|
||||
|
||||
This script:
|
||||
1. Starts two vLLM instances (one with LoRA, one without)
|
||||
2. Optionally loads a LoRA adapter
|
||||
3. Sends identical prompts to both
|
||||
4. Measures and compares TPS (tokens per second)
|
||||
This script tests multiple vLLM configurations to determine:
|
||||
1. Does --enable-lora force eager mode even without --enforce-eager?
|
||||
2. What's the actual TPS difference between configurations?
|
||||
3. Is there ANY way to get fast LoRA inference?
|
||||
|
||||
Configurations tested:
|
||||
- BASE: No LoRA flags (CUDA graphs enabled) - baseline
|
||||
- LORA_EAGER: --enable-lora --enforce-eager (required for hot-swap)
|
||||
- LORA_NO_EAGER: --enable-lora only (does vLLM force eager anyway?)
|
||||
|
||||
Usage:
|
||||
python benchmark_lora_vs_shared.py --model Qwen/Qwen3-4B-Instruct-2507
|
||||
|
|
@ -77,11 +81,18 @@ def start_vllm_server(
|
|||
model: str,
|
||||
port: int,
|
||||
gpu_id: int,
|
||||
enable_lora: bool = False,
|
||||
mode: str = "base", # "base", "lora_eager", "lora_no_eager"
|
||||
max_lora_rank: int = 32,
|
||||
log_file: str = "vllm.log",
|
||||
) -> subprocess.Popen:
|
||||
"""Start a vLLM server."""
|
||||
"""
|
||||
Start a vLLM server with different configurations.
|
||||
|
||||
Modes:
|
||||
- base: No LoRA, CUDA graphs enabled (fastest)
|
||||
- lora_eager: --enable-lora --enforce-eager (slow, but supports hot-swap)
|
||||
- lora_no_eager: --enable-lora only (test if vLLM forces eager anyway)
|
||||
"""
|
||||
# Find the vllm_api_server.py script relative to this script
|
||||
script_dir = Path(__file__).parent.parent # example_trainer/
|
||||
vllm_server_path = script_dir / "vllm_api_server.py"
|
||||
|
|
@ -99,17 +110,27 @@ def start_vllm_server(
|
|||
"--dtype", "bfloat16",
|
||||
]
|
||||
|
||||
if enable_lora:
|
||||
if mode == "lora_eager":
|
||||
cmd.extend([
|
||||
"--enable-lora",
|
||||
"--max-lora-rank", str(max_lora_rank),
|
||||
"--enforce-eager", # Required for LoRA
|
||||
"--enforce-eager",
|
||||
])
|
||||
log(f"Mode: LORA_EAGER (--enable-lora --enforce-eager)")
|
||||
elif mode == "lora_no_eager":
|
||||
cmd.extend([
|
||||
"--enable-lora",
|
||||
"--max-lora-rank", str(max_lora_rank),
|
||||
# NOTE: NOT adding --enforce-eager - testing if vLLM forces it anyway
|
||||
])
|
||||
log(f"Mode: LORA_NO_EAGER (--enable-lora only, NO --enforce-eager)")
|
||||
else:
|
||||
log(f"Mode: BASE (no LoRA flags, CUDA graphs enabled)")
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
log(f"Starting vLLM: CUDA_VISIBLE_DEVICES={gpu_id}")
|
||||
log(f"GPU: {gpu_id}")
|
||||
log(f"Command: {' '.join(cmd)}")
|
||||
|
||||
log_f = open(log_file, "w")
|
||||
|
|
@ -119,7 +140,7 @@ def start_vllm_server(
|
|||
stdout=log_f,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
log(f"Started vLLM process PID={proc.pid}, logging to {log_file}")
|
||||
log(f"Started vLLM PID={proc.pid}, log: {log_file}")
|
||||
return proc
|
||||
|
||||
|
||||
|
|
@ -189,7 +210,7 @@ def benchmark_inference(
|
|||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Benchmark LoRA vs Shared vLLM inference")
|
||||
parser = argparse.ArgumentParser(description="Benchmark LoRA inference configurations")
|
||||
parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B-Instruct-2507",
|
||||
help="Model to benchmark")
|
||||
parser.add_argument("--lora-path", type=str, default=None,
|
||||
|
|
@ -198,146 +219,166 @@ def main():
|
|||
help="Max tokens to generate")
|
||||
parser.add_argument("--num-runs", type=int, default=3,
|
||||
help="Number of benchmark runs per server")
|
||||
parser.add_argument("--lora-gpu", type=int, default=0,
|
||||
help="GPU for LoRA server")
|
||||
parser.add_argument("--shared-gpu", type=int, default=1,
|
||||
help="GPU for shared/base server")
|
||||
parser.add_argument("--lora-port", type=int, default=9001,
|
||||
help="Port for LoRA server")
|
||||
parser.add_argument("--shared-port", type=int, default=9002,
|
||||
help="Port for shared/base server")
|
||||
parser.add_argument("--gpu", type=int, default=0,
|
||||
help="GPU to use (tests run sequentially)")
|
||||
parser.add_argument("--port", type=int, default=9001,
|
||||
help="Port for vLLM server")
|
||||
parser.add_argument("--prompt", type=str, choices=["math", "long"], default="long",
|
||||
help="Which prompt to use")
|
||||
parser.add_argument("--skip-lora", action="store_true",
|
||||
help="Skip LoRA server (test base only)")
|
||||
parser.add_argument("--skip-shared", action="store_true",
|
||||
help="Skip shared/base server (test LoRA only)")
|
||||
parser.add_argument("--modes", type=str, default="all",
|
||||
help="Comma-separated modes to test: base,lora_eager,lora_no_eager or 'all'")
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = LONG_PROMPT if args.prompt == "long" else BENCHMARK_PROMPT
|
||||
|
||||
procs = []
|
||||
# Parse modes to test
|
||||
if args.modes == "all":
|
||||
modes_to_test = ["base", "lora_no_eager", "lora_eager"]
|
||||
else:
|
||||
modes_to_test = [m.strip() for m in args.modes.split(",")]
|
||||
|
||||
results = {}
|
||||
current_proc = None
|
||||
|
||||
def cleanup():
|
||||
log("\nCleaning up...")
|
||||
for p in procs:
|
||||
if current_proc:
|
||||
try:
|
||||
p.terminate()
|
||||
p.wait(timeout=5)
|
||||
current_proc.terminate()
|
||||
current_proc.wait(timeout=5)
|
||||
except Exception:
|
||||
p.kill()
|
||||
try:
|
||||
current_proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
signal.signal(signal.SIGINT, lambda s, f: (cleanup(), sys.exit(0)))
|
||||
signal.signal(signal.SIGTERM, lambda s, f: (cleanup(), sys.exit(0)))
|
||||
|
||||
try:
|
||||
log("=" * 70)
|
||||
log("vLLM Inference Benchmark: LoRA vs Base Model")
|
||||
log("vLLM LoRA Inference Configuration Benchmark")
|
||||
log("=" * 70)
|
||||
log(f"Model: {args.model}")
|
||||
log(f"LoRA adapter: {args.lora_path or 'None (base model only)'}")
|
||||
log(f"LoRA adapter: {args.lora_path or 'None'}")
|
||||
log(f"Max tokens: {args.max_tokens}")
|
||||
log(f"Num runs: {args.num_runs}")
|
||||
log(f"Prompt type: {args.prompt}")
|
||||
log(f"Modes to test: {modes_to_test}")
|
||||
log("=" * 70)
|
||||
log("")
|
||||
log("QUESTION: Does --enable-lora force eager mode even without --enforce-eager?")
|
||||
log("=" * 70)
|
||||
|
||||
# Start LoRA server
|
||||
if not args.skip_lora:
|
||||
log(f"\n[1/4] Starting LoRA-enabled vLLM on GPU {args.lora_gpu}, port {args.lora_port}...")
|
||||
log(" Flags: --enable-lora --enforce-eager (no CUDA graphs)")
|
||||
lora_proc = start_vllm_server(
|
||||
args.model, args.lora_port, args.lora_gpu,
|
||||
enable_lora=True, log_file="benchmark_lora.log"
|
||||
# Test each mode sequentially (same GPU, restart between tests)
|
||||
for i, mode in enumerate(modes_to_test):
|
||||
log(f"\n[{i+1}/{len(modes_to_test)}] Testing mode: {mode.upper()}")
|
||||
log("-" * 70)
|
||||
|
||||
# Start server
|
||||
current_proc = start_vllm_server(
|
||||
args.model, args.port, args.gpu,
|
||||
mode=mode, log_file=f"benchmark_{mode}.log"
|
||||
)
|
||||
procs.append(lora_proc)
|
||||
|
||||
# Start base/shared server
|
||||
if not args.skip_shared:
|
||||
log(f"\n[2/4] Starting base vLLM on GPU {args.shared_gpu}, port {args.shared_port}...")
|
||||
log(" Flags: (none) - uses CUDA graphs for faster inference")
|
||||
shared_proc = start_vllm_server(
|
||||
args.model, args.shared_port, args.shared_gpu,
|
||||
enable_lora=False, log_file="benchmark_shared.log"
|
||||
)
|
||||
procs.append(shared_proc)
|
||||
|
||||
# Wait for servers
|
||||
log("\n[3/4] Waiting for servers to be ready...")
|
||||
|
||||
lora_ready = False
|
||||
shared_ready = False
|
||||
|
||||
if not args.skip_lora:
|
||||
log(f" Waiting for LoRA server (port {args.lora_port})...")
|
||||
lora_ready = wait_for_server(args.lora_port, timeout=300)
|
||||
if lora_ready:
|
||||
log(f" ✓ LoRA server ready")
|
||||
|
||||
# Load LoRA adapter if provided
|
||||
if args.lora_path:
|
||||
log(f" Loading LoRA adapter from {args.lora_path}...")
|
||||
if load_lora_adapter(args.lora_port, args.lora_path):
|
||||
log(f" ✓ LoRA adapter loaded")
|
||||
|
||||
# Wait for ready
|
||||
log(f" Waiting for server (port {args.port})...")
|
||||
if not wait_for_server(args.port, timeout=300):
|
||||
log(f" ✗ Server failed to start! Check benchmark_{mode}.log")
|
||||
results[mode] = {"error": "Server failed to start"}
|
||||
current_proc.terminate()
|
||||
current_proc = None
|
||||
continue
|
||||
|
||||
log(f" ✓ Server ready")
|
||||
|
||||
# Load LoRA adapter if provided and mode supports it
|
||||
if args.lora_path and mode in ["lora_eager", "lora_no_eager"]:
|
||||
log(f" Loading LoRA adapter...")
|
||||
if load_lora_adapter(args.port, args.lora_path):
|
||||
log(f" ✓ Adapter loaded")
|
||||
else:
|
||||
log(f" ⚠ Failed to load adapter (continuing anyway)")
|
||||
|
||||
# Check the log file for CUDA graph status
|
||||
log(f" Checking CUDA graph status in log...")
|
||||
try:
|
||||
with open(f"benchmark_{mode}.log", "r") as f:
|
||||
log_content = f.read()
|
||||
if "Cudagraph is disabled" in log_content:
|
||||
log(f" ⚠ CUDA GRAPHS DISABLED (eager mode)")
|
||||
elif "cudagraph" in log_content.lower():
|
||||
# Look for other cudagraph messages
|
||||
for line in log_content.split("\n"):
|
||||
if "cudagraph" in line.lower():
|
||||
log(f" Log: {line.strip()[:80]}")
|
||||
else:
|
||||
log(f" ✗ Failed to load LoRA adapter")
|
||||
else:
|
||||
log(f" ✗ LoRA server failed to start")
|
||||
|
||||
if not args.skip_shared:
|
||||
log(f" Waiting for base server (port {args.shared_port})...")
|
||||
shared_ready = wait_for_server(args.shared_port, timeout=300)
|
||||
if shared_ready:
|
||||
log(f" ✓ Base server ready")
|
||||
else:
|
||||
log(f" ✗ Base server failed to start")
|
||||
|
||||
# Run benchmarks
|
||||
log("\n[4/4] Running benchmarks...")
|
||||
log("-" * 70)
|
||||
|
||||
lora_results = None
|
||||
shared_results = None
|
||||
|
||||
if lora_ready and not args.skip_lora:
|
||||
log(f"\nLoRA Server (--enable-lora --enforce-eager):")
|
||||
lora_results = benchmark_inference(
|
||||
args.lora_port, prompt, args.max_tokens, args.num_runs
|
||||
)
|
||||
|
||||
if shared_ready and not args.skip_shared:
|
||||
log(f"\nBase Server (CUDA graphs enabled):")
|
||||
shared_results = benchmark_inference(
|
||||
args.shared_port, prompt, args.max_tokens, args.num_runs
|
||||
log(f" (No cudagraph message found in log)")
|
||||
except Exception as e:
|
||||
log(f" (Could not read log: {e})")
|
||||
|
||||
# Run benchmark
|
||||
log(f"\n Running {args.num_runs} inference requests...")
|
||||
mode_results = benchmark_inference(
|
||||
args.port, prompt, args.max_tokens, args.num_runs
|
||||
)
|
||||
results[mode] = mode_results
|
||||
|
||||
# Terminate server
|
||||
log(f" Stopping server...")
|
||||
current_proc.terminate()
|
||||
try:
|
||||
current_proc.wait(timeout=10)
|
||||
except Exception:
|
||||
current_proc.kill()
|
||||
current_proc = None
|
||||
|
||||
# Wait for port to be free
|
||||
time.sleep(3)
|
||||
|
||||
# Print comparison
|
||||
log("\n" + "=" * 70)
|
||||
log("RESULTS SUMMARY")
|
||||
log("=" * 70)
|
||||
|
||||
if lora_results and "avg_tps" in lora_results:
|
||||
log(f"\nLoRA Mode (--enable-lora --enforce-eager):")
|
||||
log(f" Avg time: {lora_results['avg_time']:.2f}s")
|
||||
log(f" Avg tokens: {lora_results['avg_tokens']:.0f}")
|
||||
log(f" Avg TPS: {lora_results['avg_tps']:.1f}")
|
||||
valid_results = {k: v for k, v in results.items() if "avg_tps" in v}
|
||||
|
||||
if shared_results and "avg_tps" in shared_results:
|
||||
log(f"\nBase Mode (CUDA graphs):")
|
||||
log(f" Avg time: {shared_results['avg_time']:.2f}s")
|
||||
log(f" Avg tokens: {shared_results['avg_tokens']:.0f}")
|
||||
log(f" Avg TPS: {shared_results['avg_tps']:.1f}")
|
||||
for mode, res in valid_results.items():
|
||||
log(f"\n{mode.upper()}:")
|
||||
log(f" Avg time: {res['avg_time']:.2f}s")
|
||||
log(f" Avg tokens: {res['avg_tokens']:.0f}")
|
||||
log(f" Avg TPS: {res['avg_tps']:.1f}")
|
||||
|
||||
if lora_results and shared_results and "avg_tps" in lora_results and "avg_tps" in shared_results:
|
||||
speedup = shared_results["avg_tps"] / lora_results["avg_tps"] if lora_results["avg_tps"] > 0 else 0
|
||||
time_diff = lora_results["avg_time"] - shared_results["avg_time"]
|
||||
log(f"\nComparison:")
|
||||
log(f" Base is {speedup:.2f}x faster in TPS")
|
||||
log(f" Base saves {time_diff:.2f}s per request")
|
||||
log(f" --enforce-eager overhead: ~{(1 - 1/speedup) * 100:.1f}%")
|
||||
# Compare
|
||||
if "base" in valid_results:
|
||||
base_tps = valid_results["base"]["avg_tps"]
|
||||
log(f"\n" + "-" * 70)
|
||||
log("COMPARISON TO BASE (CUDA graphs enabled):")
|
||||
for mode, res in valid_results.items():
|
||||
if mode != "base":
|
||||
ratio = res["avg_tps"] / base_tps if base_tps > 0 else 0
|
||||
slowdown = (1 - ratio) * 100
|
||||
log(f" {mode}: {res['avg_tps']:.1f} TPS ({ratio:.2f}x base, {slowdown:.1f}% slower)")
|
||||
|
||||
# Key finding
|
||||
log("\n" + "=" * 70)
|
||||
log("Note: The main difference is --enforce-eager which disables CUDA graphs.")
|
||||
log("This is REQUIRED for LoRA hot-swapping but costs ~10-30% performance.")
|
||||
log("KEY FINDING:")
|
||||
if "lora_no_eager" in valid_results and "lora_eager" in valid_results:
|
||||
no_eager_tps = valid_results["lora_no_eager"]["avg_tps"]
|
||||
eager_tps = valid_results["lora_eager"]["avg_tps"]
|
||||
if abs(no_eager_tps - eager_tps) < eager_tps * 0.1: # Within 10%
|
||||
log(" ⚠ --enable-lora FORCES eager mode regardless of --enforce-eager flag!")
|
||||
log(" ⚠ There is NO WAY to get CUDA graphs with LoRA enabled in vLLM.")
|
||||
else:
|
||||
log(" ✓ --enable-lora without --enforce-eager is FASTER!")
|
||||
log(f" ✓ lora_no_eager: {no_eager_tps:.1f} TPS vs lora_eager: {eager_tps:.1f} TPS")
|
||||
|
||||
if "base" in valid_results and "lora_eager" in valid_results:
|
||||
base_tps = valid_results["base"]["avg_tps"]
|
||||
lora_tps = valid_results["lora_eager"]["avg_tps"]
|
||||
log(f"\n Base model (no LoRA): {base_tps:.1f} TPS")
|
||||
log(f" LoRA enabled: {lora_tps:.1f} TPS")
|
||||
log(f" Slowdown factor: {base_tps/lora_tps:.1f}x")
|
||||
|
||||
log("=" * 70)
|
||||
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -1,326 +0,0 @@
|
|||
#!/bin/bash
|
||||
# ============================================================================
|
||||
# Compare lora_restart vs lora_only performance
|
||||
# ============================================================================
|
||||
# Runs both modes in parallel with separate APIs/environments/ports
|
||||
# All commands run in background (single terminal)
|
||||
# Results uploaded to W&B
|
||||
#
|
||||
# Usage:
|
||||
# ./compare_lora_modes.sh [steps]
|
||||
# ./compare_lora_modes.sh 30 # 30 steps (default)
|
||||
# ./compare_lora_modes.sh 10 # Quick 10-step test
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Configuration
|
||||
MODEL="Qwen/Qwen3-4B-Instruct-2507"
|
||||
STEPS="${1:-30}"
|
||||
RESTART_INTERVAL=3
|
||||
WANDB_PROJECT="lora-mode-comparison"
|
||||
|
||||
# Port allocation
|
||||
# lora_restart: API 8001, vLLM 9001
|
||||
# lora_only: API 8002, vLLM 9002
|
||||
|
||||
echo "============================================================================"
|
||||
echo "LoRA Mode Comparison: lora_restart vs lora_only"
|
||||
echo "============================================================================"
|
||||
echo "Model: $MODEL"
|
||||
echo "Steps: $STEPS"
|
||||
echo "Restart interval: $RESTART_INTERVAL"
|
||||
echo "W&B project: $WANDB_PROJECT"
|
||||
echo ""
|
||||
echo "Port allocation:"
|
||||
echo " lora_restart: API=8001, vLLM=9001, GPU=0"
|
||||
echo " lora_only: API=8002, vLLM=9002, GPU=1"
|
||||
echo "============================================================================"
|
||||
|
||||
# Get script directory and repo root
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
cd "$REPO_ROOT"
|
||||
echo "Working directory: $(pwd)"
|
||||
|
||||
# Create log directory
|
||||
LOGDIR="./lora_comparison_$(date +%Y%m%d_%H%M%S)"
|
||||
mkdir -p "$LOGDIR"
|
||||
echo "Log directory: $LOGDIR"
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo "Cleaning up all processes..."
|
||||
|
||||
# Kill by name
|
||||
pkill -f "gsm8k_server.py" 2>/dev/null || true
|
||||
pkill -f "run-api" 2>/dev/null || true
|
||||
pkill -f "vllm_api_server.py" 2>/dev/null || true
|
||||
pkill -f "example_trainer.grpo" 2>/dev/null || true
|
||||
|
||||
# Kill by port
|
||||
for port in 8001 8002 9001 9002; do
|
||||
fuser -k ${port}/tcp 2>/dev/null || true
|
||||
done
|
||||
|
||||
echo "Cleanup complete."
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
# Initial cleanup
|
||||
echo ""
|
||||
echo "Killing any existing processes on ports 8001, 8002, 9001, 9002..."
|
||||
cleanup
|
||||
sleep 3
|
||||
|
||||
# ============================================================================
|
||||
# MODE 1: lora_restart (GPU 0, ports 8001/9001)
|
||||
# ============================================================================
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo "[1/2] LORA_RESTART MODE (GPU 0, API:8001, vLLM:9001)"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
# Start API for lora_restart
|
||||
echo " Starting API server (port 8001)..."
|
||||
run-api --port 8001 > "$LOGDIR/api_restart.log" 2>&1 &
|
||||
RESTART_API_PID=$!
|
||||
sleep 3
|
||||
|
||||
# Check API is up
|
||||
if curl -s "http://localhost:8001/info" > /dev/null 2>&1; then
|
||||
echo " ✓ API running (PID: $RESTART_API_PID)"
|
||||
else
|
||||
echo " ✗ API failed to start"
|
||||
cat "$LOGDIR/api_restart.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Start trainer (lora_restart manages vLLM internally)
|
||||
echo " Starting lora_restart trainer (will launch vLLM on port 9001)..."
|
||||
CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
|
||||
--model-name "$MODEL" \
|
||||
--weight-bridge-mode lora_restart \
|
||||
--vllm-port 9001 \
|
||||
--atropos-url http://localhost:8001 \
|
||||
--lora-r 16 \
|
||||
--lora-alpha 32 \
|
||||
--training-steps $STEPS \
|
||||
--vllm-restart-interval $RESTART_INTERVAL \
|
||||
--save-path "$LOGDIR/checkpoints_restart" \
|
||||
--use-wandb \
|
||||
--wandb-project "$WANDB_PROJECT" \
|
||||
--wandb-group "comparison-$(date +%Y%m%d)" \
|
||||
--benchmark \
|
||||
> "$LOGDIR/trainer_restart.log" 2>&1 &
|
||||
RESTART_TRAINER_PID=$!
|
||||
echo " ✓ Trainer started (PID: $RESTART_TRAINER_PID)"
|
||||
|
||||
# Wait for vLLM to be ready (trainer launches it)
|
||||
echo " Waiting for vLLM to start (port 9001)..."
|
||||
for i in {1..60}; do
|
||||
if curl -s "http://localhost:9001/health" > /dev/null 2>&1; then
|
||||
echo " ✓ vLLM ready after ~${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
|
||||
# Start environment for lora_restart
|
||||
echo " Starting environment..."
|
||||
python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name "$MODEL" \
|
||||
--env.rollout_server_url "http://localhost:8001" \
|
||||
--env.max_token_length 2048 \
|
||||
--env.use_wandb=True \
|
||||
--env.wandb_name "lora-restart-env" \
|
||||
--openai.model_name "$MODEL" \
|
||||
--openai.base_url "http://localhost:9001/v1" \
|
||||
--openai.server_type vllm \
|
||||
--slurm false \
|
||||
> "$LOGDIR/env_restart.log" 2>&1 &
|
||||
RESTART_ENV_PID=$!
|
||||
echo " ✓ Environment started (PID: $RESTART_ENV_PID)"
|
||||
|
||||
# ============================================================================
|
||||
# MODE 2: lora_only (GPU 1, ports 8002/9002)
|
||||
# ============================================================================
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo "[2/2] LORA_ONLY MODE (GPU 1, API:8002, vLLM:9002)"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
# Start API for lora_only
|
||||
echo " Starting API server (port 8002)..."
|
||||
run-api --port 8002 > "$LOGDIR/api_only.log" 2>&1 &
|
||||
ONLY_API_PID=$!
|
||||
sleep 3
|
||||
|
||||
# Check API is up
|
||||
if curl -s "http://localhost:8002/info" > /dev/null 2>&1; then
|
||||
echo " ✓ API running (PID: $ONLY_API_PID)"
|
||||
else
|
||||
echo " ✗ API failed to start"
|
||||
cat "$LOGDIR/api_only.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Start vLLM for lora_only (external, with --enforce-eager)
|
||||
echo " Starting vLLM with --enable-lora --enforce-eager (port 9002)..."
|
||||
CUDA_VISIBLE_DEVICES=1 python example_trainer/vllm_api_server.py \
|
||||
--model "$MODEL" \
|
||||
--port 9002 \
|
||||
--gpu-memory-utilization 0.45 \
|
||||
--enable-lora \
|
||||
--max-lora-rank 32 \
|
||||
--enforce-eager \
|
||||
> "$LOGDIR/vllm_only.log" 2>&1 &
|
||||
ONLY_VLLM_PID=$!
|
||||
echo " ✓ vLLM started (PID: $ONLY_VLLM_PID)"
|
||||
|
||||
# Wait for vLLM to be ready
|
||||
echo " Waiting for vLLM to start (port 9002)..."
|
||||
for i in {1..90}; do
|
||||
if curl -s "http://localhost:9002/health" > /dev/null 2>&1; then
|
||||
echo " ✓ vLLM ready after ~${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
|
||||
# Start environment for lora_only
|
||||
echo " Starting environment..."
|
||||
python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name "$MODEL" \
|
||||
--env.rollout_server_url "http://localhost:8002" \
|
||||
--env.max_token_length 2048 \
|
||||
--env.use_wandb=True \
|
||||
--env.wandb_name "lora-only-env" \
|
||||
--openai.model_name "$MODEL" \
|
||||
--openai.base_url "http://localhost:9002/v1" \
|
||||
--openai.server_type vllm \
|
||||
--slurm false \
|
||||
> "$LOGDIR/env_only.log" 2>&1 &
|
||||
ONLY_ENV_PID=$!
|
||||
echo " ✓ Environment started (PID: $ONLY_ENV_PID)"
|
||||
|
||||
# Start trainer for lora_only
|
||||
echo " Starting lora_only trainer..."
|
||||
CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \
|
||||
--model-name "$MODEL" \
|
||||
--weight-bridge-mode lora_only \
|
||||
--vllm-port 9002 \
|
||||
--atropos-url http://localhost:8002 \
|
||||
--lora-r 16 \
|
||||
--lora-alpha 32 \
|
||||
--training-steps $STEPS \
|
||||
--save-path "$LOGDIR/checkpoints_only" \
|
||||
--use-wandb \
|
||||
--wandb-project "$WANDB_PROJECT" \
|
||||
--wandb-group "comparison-$(date +%Y%m%d)" \
|
||||
--benchmark \
|
||||
> "$LOGDIR/trainer_only.log" 2>&1 &
|
||||
ONLY_TRAINER_PID=$!
|
||||
echo " ✓ Trainer started (PID: $ONLY_TRAINER_PID)"
|
||||
|
||||
# ============================================================================
|
||||
# Save PIDs and monitor
|
||||
# ============================================================================
|
||||
cat > "$LOGDIR/pids.txt" << EOF
|
||||
RESTART_API_PID=$RESTART_API_PID
|
||||
RESTART_TRAINER_PID=$RESTART_TRAINER_PID
|
||||
RESTART_ENV_PID=$RESTART_ENV_PID
|
||||
ONLY_API_PID=$ONLY_API_PID
|
||||
ONLY_VLLM_PID=$ONLY_VLLM_PID
|
||||
ONLY_ENV_PID=$ONLY_ENV_PID
|
||||
ONLY_TRAINER_PID=$ONLY_TRAINER_PID
|
||||
EOF
|
||||
|
||||
echo ""
|
||||
echo "============================================================================"
|
||||
echo "All components started!"
|
||||
echo "============================================================================"
|
||||
echo ""
|
||||
echo "📊 Monitor progress:"
|
||||
echo " tail -f $LOGDIR/trainer_restart.log # lora_restart"
|
||||
echo " tail -f $LOGDIR/trainer_only.log # lora_only"
|
||||
echo ""
|
||||
echo "🔍 Watch both:"
|
||||
echo " tail -f $LOGDIR/trainer_*.log"
|
||||
echo ""
|
||||
echo "📈 W&B Dashboard:"
|
||||
echo " https://wandb.ai/$WANDB_PROJECT"
|
||||
echo ""
|
||||
echo "Waiting for trainers to complete..."
|
||||
echo "(lora_restart should finish MUCH faster than lora_only)"
|
||||
echo ""
|
||||
|
||||
# Wait for trainers
|
||||
RESTART_STATUS="running"
|
||||
ONLY_STATUS="running"
|
||||
|
||||
while [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; do
|
||||
sleep 30
|
||||
|
||||
# Check lora_restart
|
||||
if [ "$RESTART_STATUS" = "running" ]; then
|
||||
if ! kill -0 $RESTART_TRAINER_PID 2>/dev/null; then
|
||||
wait $RESTART_TRAINER_PID 2>/dev/null && RESTART_STATUS="completed" || RESTART_STATUS="failed"
|
||||
echo " lora_restart: $RESTART_STATUS"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check lora_only
|
||||
if [ "$ONLY_STATUS" = "running" ]; then
|
||||
if ! kill -0 $ONLY_TRAINER_PID 2>/dev/null; then
|
||||
wait $ONLY_TRAINER_PID 2>/dev/null && ONLY_STATUS="completed" || ONLY_STATUS="failed"
|
||||
echo " lora_only: $ONLY_STATUS"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Show status
|
||||
if [ "$RESTART_STATUS" = "running" ] || [ "$ONLY_STATUS" = "running" ]; then
|
||||
echo " [$(date +%H:%M:%S)] lora_restart: $RESTART_STATUS, lora_only: $ONLY_STATUS"
|
||||
fi
|
||||
done
|
||||
|
||||
# ============================================================================
|
||||
# Print results
|
||||
# ============================================================================
|
||||
echo ""
|
||||
echo "============================================================================"
|
||||
echo "COMPARISON RESULTS"
|
||||
echo "============================================================================"
|
||||
|
||||
echo ""
|
||||
echo "📊 LORA_RESTART (CUDA graphs, vLLM restarts):"
|
||||
echo "─────────────────────────────────────────────────"
|
||||
grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_restart.log" 2>/dev/null || echo " (check $LOGDIR/trainer_restart.log)"
|
||||
|
||||
echo ""
|
||||
echo "📊 LORA_ONLY (--enforce-eager, hot-swap):"
|
||||
echo "─────────────────────────────────────────────────"
|
||||
grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer_only.log" 2>/dev/null || echo " (check $LOGDIR/trainer_only.log)"
|
||||
|
||||
echo ""
|
||||
echo "============================================================================"
|
||||
echo "📁 LOGS SAVED TO: $LOGDIR"
|
||||
echo "============================================================================"
|
||||
echo ""
|
||||
echo "Log files:"
|
||||
echo " $LOGDIR/trainer_restart.log # lora_restart trainer"
|
||||
echo " $LOGDIR/trainer_only.log # lora_only trainer"
|
||||
echo " $LOGDIR/vllm_only.log # lora_only vLLM"
|
||||
echo " $LOGDIR/env_restart.log # lora_restart environment"
|
||||
echo " $LOGDIR/env_only.log # lora_only environment"
|
||||
echo ""
|
||||
echo "Checkpoints:"
|
||||
echo " $LOGDIR/checkpoints_restart/"
|
||||
echo " $LOGDIR/checkpoints_only/"
|
||||
echo ""
|
||||
echo "W&B runs should be visible at:"
|
||||
echo " https://wandb.ai/$WANDB_PROJECT"
|
||||
echo ""
|
||||
echo "============================================================================"
|
||||
echo "Done!"
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Quick test script for lora_restart mode
|
||||
# Tests that the mode works and compares timing
|
||||
|
||||
set -e
|
||||
|
||||
MODEL="Qwen/Qwen3-4B-Instruct-2507"
|
||||
STEPS=10
|
||||
RESTART_INTERVAL=3
|
||||
|
||||
echo "=============================================="
|
||||
echo "Testing lora_restart mode"
|
||||
echo "=============================================="
|
||||
echo "Model: $MODEL"
|
||||
echo "Steps: $STEPS"
|
||||
echo "Restart interval: $RESTART_INTERVAL"
|
||||
echo "=============================================="
|
||||
|
||||
# Get script directory
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# Create log directory
|
||||
LOGDIR="./lora_restart_test_$(date +%Y%m%d_%H%M%S)"
|
||||
mkdir -p "$LOGDIR"
|
||||
echo "Logs: $LOGDIR"
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Cleaning up..."
|
||||
pkill -f "gsm8k_server.py" 2>/dev/null || true
|
||||
pkill -f "run-api" 2>/dev/null || true
|
||||
pkill -f "vllm_api_server.py" 2>/dev/null || true
|
||||
# Kill by port
|
||||
for port in 8000 9001; do
|
||||
fuser -k ${port}/tcp 2>/dev/null || true
|
||||
done
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
# Kill any existing processes
|
||||
cleanup
|
||||
sleep 2
|
||||
|
||||
# Start API server
|
||||
echo ""
|
||||
echo "[1/3] Starting Atropos API..."
|
||||
run-api --port 8000 > "$LOGDIR/api.log" 2>&1 &
|
||||
API_PID=$!
|
||||
sleep 3
|
||||
|
||||
# Check API is up
|
||||
if ! curl -s "http://localhost:8000/info" > /dev/null 2>&1; then
|
||||
echo "ERROR: API server failed to start"
|
||||
cat "$LOGDIR/api.log"
|
||||
exit 1
|
||||
fi
|
||||
echo " ✓ API running (PID: $API_PID)"
|
||||
|
||||
# Start trainer (lora_restart manages vLLM internally)
|
||||
echo ""
|
||||
echo "[2/3] Starting lora_restart trainer..."
|
||||
echo " (This will launch vLLM internally)"
|
||||
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
|
||||
--model-name "$MODEL" \
|
||||
--weight-bridge-mode lora_restart \
|
||||
--vllm-port 9001 \
|
||||
--atropos-url http://localhost:8000 \
|
||||
--lora-r 16 \
|
||||
--lora-alpha 32 \
|
||||
--training-steps $STEPS \
|
||||
--vllm-restart-interval $RESTART_INTERVAL \
|
||||
--save-path "$LOGDIR/checkpoints" \
|
||||
--benchmark \
|
||||
> "$LOGDIR/trainer.log" 2>&1 &
|
||||
TRAINER_PID=$!
|
||||
|
||||
# Wait for vLLM to start (trainer launches it)
|
||||
echo " Waiting for trainer to launch vLLM..."
|
||||
sleep 30
|
||||
|
||||
# Start environment (needs to wait for vLLM)
|
||||
echo ""
|
||||
echo "[3/3] Starting GSM8K environment..."
|
||||
python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name "$MODEL" \
|
||||
--env.rollout_server_url "http://localhost:8000" \
|
||||
--env.max_token_length 2048 \
|
||||
--env.use_wandb=False \
|
||||
--openai.model_name "$MODEL" \
|
||||
--openai.base_url "http://localhost:9001/v1" \
|
||||
--openai.server_type vllm \
|
||||
--slurm false \
|
||||
> "$LOGDIR/env.log" 2>&1 &
|
||||
ENV_PID=$!
|
||||
sleep 5
|
||||
echo " ✓ Environment running (PID: $ENV_PID)"
|
||||
|
||||
# Wait for trainer to complete
|
||||
echo ""
|
||||
echo "Waiting for training to complete..."
|
||||
echo "(Check progress: tail -f $LOGDIR/trainer.log)"
|
||||
|
||||
wait $TRAINER_PID
|
||||
TRAINER_EXIT=$?
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
ELAPSED=$((END_TIME - START_TIME))
|
||||
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "TEST RESULTS"
|
||||
echo "=============================================="
|
||||
|
||||
if [ $TRAINER_EXIT -eq 0 ]; then
|
||||
echo "✓ Training completed successfully!"
|
||||
echo " Time: ${ELAPSED}s"
|
||||
echo ""
|
||||
echo "Checkpoints:"
|
||||
ls -la "$LOGDIR/checkpoints/" 2>/dev/null || echo " (no checkpoints found)"
|
||||
echo ""
|
||||
echo "Benchmark summary:"
|
||||
grep -A 20 "BENCHMARK SUMMARY" "$LOGDIR/trainer.log" 2>/dev/null || echo " (no benchmark found)"
|
||||
else
|
||||
echo "✗ Training FAILED (exit code: $TRAINER_EXIT)"
|
||||
echo ""
|
||||
echo "Last 50 lines of trainer log:"
|
||||
tail -50 "$LOGDIR/trainer.log"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "Log files saved to: $LOGDIR"
|
||||
echo "=============================================="
|
||||
Loading…
Add table
Add a link
Reference in a new issue