major refactor 2

This commit is contained in:
Jai Suphavadeeprasit 2026-01-22 12:04:29 -05:00
parent 6833d4d820
commit 3a1229afaf
4 changed files with 175 additions and 33 deletions

View file

@ -37,7 +37,8 @@
# ├── trainer_legacy.log # GRPO trainer log (MAIN OUTPUT)
# ├── trainer_shared.log # GRPO trainer log (MAIN OUTPUT)
# ├── trainer_lora.log # GRPO trainer log (MAIN OUTPUT)
# ├── vllm_bridge_config.json # CUDA IPC config (shared mode)
# ├── vllm_bridge_config_shared.json # CUDA IPC config (shared mode)
# ├── vllm_bridge_config_lora.json # CUDA IPC config (lora mode)
# ├── pids.txt # Process IDs for cleanup
# ├── checkpoints_legacy/ # Model checkpoints
# │ ├── step_3/
@ -63,6 +64,53 @@ export LOGDIR="${LOGDIR:-./comparison_$(date +%Y%m%d_%H%M%S)}"
mkdir -p $LOGDIR
# ==============================================================================
# Helper function: Wait for vLLM to be ready
# ==============================================================================
# Poll a vLLM server's /health endpoint on localhost until it answers.
#
# Arguments:
#   $1 - port the vLLM server listens on
#   $2 - human-readable name used in the progress messages
#   $3 - maximum number of probes (default 60; with 5s between probes that
#        is roughly 5 minutes)
# Outputs:
#   Progress and result messages on stdout.
# Returns:
#   0 as soon as /health answers with a successful HTTP status,
#   1 if every probe fails.
wait_for_vllm() {
  local port=$1
  local name=$2
  local max_attempts=${3:-60}
  local interval=5
  local attempt=1

  echo " Waiting for vLLM ($name) on port $port..."
  while [ "$attempt" -le "$max_attempts" ]; do
    # -f makes HTTP error statuses (>= 400) count as "not ready" instead of
    # success; --max-time keeps a hung connection from stalling the loop.
    if curl -sf --max-time "$interval" "http://localhost:$port/health" > /dev/null 2>&1; then
      # Only (attempt - 1) sleeps have elapsed when probe number `attempt` succeeds.
      echo " ✓ vLLM ($name) is ready after ~$(( (attempt - 1) * interval ))s"
      return 0
    fi
    echo " Attempt $attempt/$max_attempts - vLLM not ready yet..."
    sleep "$interval"
    attempt=$((attempt + 1))
  done
  echo " ✗ vLLM ($name) failed to start after $((max_attempts * interval))s"
  return 1
}
# ==============================================================================
# Helper function: Wait for API server to be ready
# ==============================================================================
# Poll an API server's /info endpoint on localhost until it answers.
#
# Arguments:
#   $1 - port the API server listens on
#   $2 - human-readable name used in the progress messages
#   $3 - maximum number of probes (default 20; 2s between probes)
# Outputs:
#   Progress and result messages on stdout.
# Returns:
#   0 as soon as /info answers with a successful HTTP status,
#   1 if every probe fails.
wait_for_api() {
  local port=$1
  local name=$2
  local max_attempts=${3:-20}
  local interval=2
  local attempt=1

  echo " Waiting for API ($name) on port $port..."
  while [ "$attempt" -le "$max_attempts" ]; do
    # -f makes HTTP error statuses (>= 400) count as "not ready" instead of
    # success; --max-time keeps a hung connection from stalling the loop.
    if curl -sf --max-time "$interval" "http://localhost:$port/info" > /dev/null 2>&1; then
      echo " ✓ API ($name) is ready"
      return 0
    fi
    sleep "$interval"
    attempt=$((attempt + 1))
  done
  echo " ✗ API ($name) failed to start"
  return 1
}
echo "=============================================="
echo "GRPO Training Mode Comparison"
echo "=============================================="
@ -109,22 +157,10 @@ echo " Starting run-api server..."
run-api --port 8001 > $LOGDIR/api_legacy.log 2>&1 &
LEGACY_API_PID=$!
echo " ✓ run-api started (PID: $LEGACY_API_PID, port 8001)"
sleep 3
wait_for_api 8001 "legacy" || { echo "Failed to start legacy API"; exit 1; }
# Start environment server for Legacy
echo " Starting environment server..."
python environments/gsm8k_server.py serve \
--slurm.num_gpus 0 \
--env.tokenizer_name $MODEL \
--openai.base_url http://localhost:9001/v1 \
--server.port 8001 \
> $LOGDIR/env_legacy.log 2>&1 &
LEGACY_ENV_PID=$!
echo " ✓ Environment server started (PID: $LEGACY_ENV_PID)"
sleep 5
# Start Legacy trainer (it will manage its own vLLM)
echo " Starting trainer..."
# In legacy mode: trainer launches vLLM internally, so start trainer FIRST
echo " Starting trainer (will launch internal vLLM on port 9001)..."
CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
--model-name $MODEL \
--weight-bridge-mode none \
@ -138,6 +174,21 @@ CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
LEGACY_TRAINER_PID=$!
echo " ✓ Trainer started (PID: $LEGACY_TRAINER_PID)"
# Wait for trainer's internal vLLM to be ready
wait_for_vllm 9001 "legacy (internal)" || { echo "Legacy vLLM failed to start"; exit 1; }
# NOW start environment server (after vLLM is ready)
echo " Starting environment server..."
python environments/gsm8k_server.py serve \
--slurm.num_gpus 0 \
--env.tokenizer_name $MODEL \
--openai.base_url http://localhost:9001/v1 \
--server.port 8001 \
> $LOGDIR/env_legacy.log 2>&1 &
LEGACY_ENV_PID=$!
echo " ✓ Environment server started (PID: $LEGACY_ENV_PID)"
sleep 3
# ==============================================================================
# MODE 2: SHARED_VLLM (GPUs 2-3, API 8002, vLLM 9002)
# ==============================================================================
@ -151,11 +202,11 @@ echo " Starting run-api server..."
run-api --port 8002 > $LOGDIR/api_shared.log 2>&1 &
SHARED_API_PID=$!
echo " ✓ run-api started (PID: $SHARED_API_PID, port 8002)"
sleep 3
wait_for_api 8002 "shared" || { echo "Failed to start shared API"; exit 1; }
# Start vLLM with shared weights
# Start vLLM with shared weights (use separate config path)
echo " Starting vLLM with shared weights..."
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
VLLM_ENABLE_SHARED_WEIGHTS=1 VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_shared.json \
CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
--model $MODEL \
--port 9002 \
@ -163,8 +214,7 @@ CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
> $LOGDIR/vllm_shared.log 2>&1 &
SHARED_VLLM_PID=$!
echo " ✓ vLLM started (PID: $SHARED_VLLM_PID, port 9002)"
echo " Waiting for vLLM to initialize (30s)..."
sleep 30
wait_for_vllm 9002 "shared" || { echo "Failed to start shared vLLM"; exit 1; }
# Start environment server for Shared
echo " Starting environment server..."
@ -184,7 +234,7 @@ CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \
--model-name $MODEL \
--weight-bridge-mode shared_vllm \
--vllm-port 9002 \
--vllm-config-path $LOGDIR/vllm_bridge_config.json \
--vllm-config-path $LOGDIR/vllm_bridge_config_shared.json \
--atropos-url http://localhost:8002 \
--training-steps $TRAINING_STEPS \
--batch-size $BATCH_SIZE \
@ -207,10 +257,11 @@ echo " Starting run-api server..."
run-api --port 8003 > $LOGDIR/api_lora.log 2>&1 &
LORA_API_PID=$!
echo " ✓ run-api started (PID: $LORA_API_PID, port 8003)"
sleep 3
wait_for_api 8003 "lora" || { echo "Failed to start lora API"; exit 1; }
# Start vLLM with LoRA support
# Start vLLM with LoRA support (use separate config path)
echo " Starting vLLM with LoRA support..."
VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_lora.json \
CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \
--model $MODEL \
--port 9003 \
@ -221,8 +272,7 @@ CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \
> $LOGDIR/vllm_lora.log 2>&1 &
LORA_VLLM_PID=$!
echo " ✓ vLLM started (PID: $LORA_VLLM_PID, port 9003)"
echo " Waiting for vLLM to initialize (30s)..."
sleep 30
wait_for_vllm 9003 "lora" || { echo "Failed to start lora vLLM"; exit 1; }
# Start environment server for LoRA
echo " Starting environment server..."
@ -356,7 +406,8 @@ echo " LoRA: $LOGDIR/checkpoints_lora/final_adapter/"
echo ""
echo "🔧 OTHER:"
echo " Process IDs: $LOGDIR/pids.txt"
echo " IPC Config: $LOGDIR/vllm_bridge_config.json"
echo " IPC Config (shared): $LOGDIR/vllm_bridge_config_shared.json"
echo " IPC Config (lora): $LOGDIR/vllm_bridge_config_lora.json"
echo "=============================================="
echo ""
echo "To re-run or inspect later:"