diff --git a/example_trainer/scripts/run_comparison.sh b/example_trainer/scripts/run_comparison.sh deleted file mode 100755 index fa56a5cf..00000000 --- a/example_trainer/scripts/run_comparison.sh +++ /dev/null @@ -1,462 +0,0 @@ -#!/bin/bash -# ============================================================================== -# GRPO Training Mode Comparison Script -# ============================================================================== -# Runs all three training modes (Legacy, Shared vLLM, LoRA) in parallel -# on an 8-GPU node for comparison. -# -# GPU Allocation: -# - GPUs 0-1: Legacy mode (trainer manages vLLM) -# - GPUs 2-3: Shared vLLM mode (CUDA IPC single-copy) -# - GPUs 4-5: LoRA mode (adapter training) -# - GPUs 6-7: Reserved -# -# Port Allocation: -# - Legacy: API 8001, vLLM 9001 -# - Shared vLLM: API 8002, vLLM 9002 -# - LoRA: API 8003, vLLM 9003 -# -# Usage: -# ./run_comparison.sh # Default 50 steps, logs to ./comparison_ -# ./run_comparison.sh 100 # 100 steps -# LOGDIR=/my/path ./run_comparison.sh # Custom log directory -# -# ============================================================================== -# OUTPUT DIRECTORY STRUCTURE ($LOGDIR): -# ============================================================================== -# -# $LOGDIR/ -# ├── api_legacy.log # run-api server log (port 8001) -# ├── api_shared.log # run-api server log (port 8002) -# ├── api_lora.log # run-api server log (port 8003) -# ├── env_legacy.log # gsm8k environment log -# ├── env_shared.log # gsm8k environment log -# ├── env_lora.log # gsm8k environment log -# ├── vllm_shared.log # vLLM server log (shared mode) -# ├── vllm_lora.log # vLLM server log (lora mode) -# ├── trainer_legacy.log # GRPO trainer log (MAIN OUTPUT) -# ├── trainer_shared.log # GRPO trainer log (MAIN OUTPUT) -# ├── trainer_lora.log # GRPO trainer log (MAIN OUTPUT) -# ├── vllm_bridge_config_shared.json # CUDA IPC config (shared mode) -# ├── vllm_bridge_config_lora.json # CUDA IPC config (lora mode) -# ├── pids.txt # Process IDs for cleanup -# ├── checkpoints_legacy/ # Model checkpoints -# │ ├── step_3/ -# │ ├── step_6/ -# │ └── final_model/ -# ├── checkpoints_shared/ # Model checkpoints -# │ └── final_model/ -# └── checkpoints_lora/ # LoRA adapter checkpoints -# ├── adapter_step_3/ -# ├── adapter_step_6/ -# ├── final_adapter/ -# └── tokenizer/ -# -# ============================================================================== - -set -e - -# Get script directory and repo root -# Script is at: atropos/example_trainer/scripts/run_comparison.sh -# We need to be at: atropos/ (where example_trainer/ and environments/ are) -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" - -# Change to repo root so all relative paths work correctly -cd "$REPO_ROOT" -echo "Working directory: $(pwd)" - -# Configuration -export MODEL="${MODEL:-Qwen/Qwen2.5-3B-Instruct}" -export TRAINING_STEPS="${1:-50}" -export BATCH_SIZE="${BATCH_SIZE:-2}" -export LOGDIR="${LOGDIR:-./comparison_$(date +%Y%m%d_%H%M%S)}" - -mkdir -p $LOGDIR - -# ============================================================================== -# Helper function: Wait for vLLM to be ready -# ============================================================================== -wait_for_vllm() { - local port=$1 - local name=$2 - local max_attempts=${3:-60} # Default 60 attempts (5 minutes with 5s sleep) - local attempt=1 - - echo " Waiting for vLLM ($name) on port $port..." 
- while [ $attempt -le $max_attempts ]; do - if curl -s "http://localhost:$port/health" > /dev/null 2>&1; then - echo " ✓ vLLM ($name) is ready after ~$((attempt * 5))s" - return 0 - fi - echo " Attempt $attempt/$max_attempts - vLLM not ready yet..." - sleep 5 - attempt=$((attempt + 1)) - done - - echo " ✗ vLLM ($name) failed to start after $((max_attempts * 5))s" - return 1 -} - -# ============================================================================== -# Helper function: Wait for API server to be ready -# ============================================================================== -wait_for_api() { - local port=$1 - local name=$2 - local max_attempts=${3:-20} - local attempt=1 - - echo " Waiting for API ($name) on port $port..." - while [ $attempt -le $max_attempts ]; do - if curl -s "http://localhost:$port/info" > /dev/null 2>&1; then - echo " ✓ API ($name) is ready" - return 0 - fi - sleep 2 - attempt=$((attempt + 1)) - done - - echo " ✗ API ($name) failed to start" - return 1 -} - -echo "==============================================" -echo "GRPO Training Mode Comparison" -echo "==============================================" -echo "Model: $MODEL" -echo "Steps: $TRAINING_STEPS" -echo "Batch size: $BATCH_SIZE" -echo "Log dir: $LOGDIR" -echo "==============================================" -echo "" - -# Cleanup function -cleanup() { - echo "" - echo "Cleaning up processes..." - if [ -f "$LOGDIR/pids.txt" ]; then - source $LOGDIR/pids.txt - kill $LEGACY_TRAINER_PID $LEGACY_ENV_PID $LEGACY_API_PID 2>/dev/null || true - kill $SHARED_TRAINER_PID $SHARED_VLLM_PID $SHARED_ENV_PID $SHARED_API_PID 2>/dev/null || true - kill $LORA_TRAINER_PID $LORA_VLLM_PID $LORA_ENV_PID $LORA_API_PID 2>/dev/null || true - fi - pkill -f "gsm8k_server.py" 2>/dev/null || true - pkill -f "run-api" 2>/dev/null || true - pkill -f "vllm_api_server.py" 2>/dev/null || true - # Also kill by port - for port in 8001 8002 8003 9001 9002 9003; do - fuser -k ${port}/tcp 2>/dev/null || true - done - echo "Cleanup complete." -} -trap cleanup EXIT - -# Kill any existing processes on our ports - be aggressive! -echo "Killing any existing processes on ports 8001-8003, 9001-9003..." - -# Kill by process name patterns -pkill -9 -f "vllm_api_server.py" 2>/dev/null || true -pkill -9 -f "gsm8k_server.py" 2>/dev/null || true -pkill -9 -f "run-api" 2>/dev/null || true -pkill -9 -f "grpo" 2>/dev/null || true - -# Kill by port using fuser (more reliable) -for port in 8001 8002 8003 9001 9002 9003; do - fuser -k ${port}/tcp 2>/dev/null || true -done - -sleep 3 - -# Verify ports are free -for port in 8001 8002 8003 9001 9002 9003; do - if lsof -i :${port} > /dev/null 2>&1; then - echo "WARNING: Port ${port} still in use!" - lsof -i :${port} | head -3 - fi -done - -# ============================================================================== -# MODE 1: LEGACY (GPUs 0-1, API 8001, vLLM 9001) -# ============================================================================== -echo "" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" -echo "[1/3] LEGACY MODE (GPUs 0-1, API:8001, vLLM:9001)" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - -# Start run-api server for Legacy -echo " Starting run-api server..." -run-api --port 8001 > $LOGDIR/api_legacy.log 2>&1 & -LEGACY_API_PID=$! 
-echo " ✓ run-api started (PID: $LEGACY_API_PID, port 8001)" -wait_for_api 8001 "legacy" || { echo "Failed to start legacy API"; exit 1; } - -# In legacy mode: trainer launches vLLM internally, so start trainer FIRST -echo " Starting trainer (will launch internal vLLM on port 9001)..." -CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \ - --model-name $MODEL \ - --weight-bridge-mode none \ - --vllm-port 9001 \ - --vllm-gpu-memory-utilization 0.35 \ - --atropos-url http://localhost:8001 \ - --training-steps $TRAINING_STEPS \ - --batch-size $BATCH_SIZE \ - --save-path $LOGDIR/checkpoints_legacy \ - --benchmark \ - > $LOGDIR/trainer_legacy.log 2>&1 & -LEGACY_TRAINER_PID=$! -echo " ✓ Trainer started (PID: $LEGACY_TRAINER_PID)" - -# Wait for trainer's internal vLLM to be ready -wait_for_vllm 9001 "legacy (internal)" || { echo "Legacy vLLM failed to start"; exit 1; } - -# NOW start environment server (after vLLM is ready) -echo " Starting environment server..." -python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name "$MODEL" \ - --env.use_wandb=False \ - --env.rollout_server_url "http://localhost:8001" \ - --openai.model_name "$MODEL" \ - --openai.base_url "http://localhost:9001/v1" \ - --openai.server_type vllm \ - --slurm false \ - > $LOGDIR/env_legacy.log 2>&1 & -LEGACY_ENV_PID=$! -echo " ✓ Environment server started (PID: $LEGACY_ENV_PID)" -sleep 3 - -# ============================================================================== -# MODE 2: SHARED_VLLM (GPUs 2-3, API 8002, vLLM 9002) -# ============================================================================== -echo "" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" -echo "[2/3] SHARED_VLLM MODE (GPUs 2-3, API:8002, vLLM:9002)" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - -# Start run-api server for Shared -echo " Starting run-api server..." -run-api --port 8002 > $LOGDIR/api_shared.log 2>&1 & -SHARED_API_PID=$! -echo " ✓ run-api started (PID: $SHARED_API_PID, port 8002)" -wait_for_api 8002 "shared" || { echo "Failed to start shared API"; exit 1; } - -# Start vLLM with shared weights (use separate config path) -# NOTE: --enforce-eager is REQUIRED for single-copy mode! -# Without it, CUDA graphs freeze weights and updates won't be visible to inference. -echo " Starting vLLM with shared weights..." -VLLM_ENABLE_SHARED_WEIGHTS=1 VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_shared.json \ -CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \ - --model $MODEL \ - --port 9002 \ - --gpu-memory-utilization 0.35 \ - --enforce-eager \ - > $LOGDIR/vllm_shared.log 2>&1 & -SHARED_VLLM_PID=$! -echo " ✓ vLLM started (PID: $SHARED_VLLM_PID, port 9002)" -wait_for_vllm 9002 "shared" || { echo "Failed to start shared vLLM"; exit 1; } - -# Start environment server for Shared -echo " Starting environment server..." -python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name "$MODEL" \ - --env.use_wandb=False \ - --env.rollout_server_url "http://localhost:8002" \ - --openai.model_name "$MODEL" \ - --openai.base_url "http://localhost:9002/v1" \ - --openai.server_type vllm \ - --slurm false \ - > $LOGDIR/env_shared.log 2>&1 & -SHARED_ENV_PID=$! -echo " ✓ Environment server started (PID: $SHARED_ENV_PID)" -sleep 5 - -# Start Shared vLLM trainer -echo " Starting trainer..." 
-CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \ - --model-name $MODEL \ - --weight-bridge-mode shared_vllm \ - --vllm-port 9002 \ - --vllm-config-path $LOGDIR/vllm_bridge_config_shared.json \ - --atropos-url http://localhost:8002 \ - --training-steps $TRAINING_STEPS \ - --batch-size $BATCH_SIZE \ - --save-path $LOGDIR/checkpoints_shared \ - --benchmark \ - > $LOGDIR/trainer_shared.log 2>&1 & -SHARED_TRAINER_PID=$! -echo " ✓ Trainer started (PID: $SHARED_TRAINER_PID)" - -# ============================================================================== -# MODE 3: LORA (GPUs 4-5, API 8003, vLLM 9003) -# ============================================================================== -echo "" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" -echo "[3/3] LORA MODE (GPUs 4-5, API:8003, vLLM:9003)" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - -# Start run-api server for LoRA -echo " Starting run-api server..." -run-api --port 8003 > $LOGDIR/api_lora.log 2>&1 & -LORA_API_PID=$! -echo " ✓ run-api started (PID: $LORA_API_PID, port 8003)" -wait_for_api 8003 "lora" || { echo "Failed to start lora API"; exit 1; } - -# Start vLLM with LoRA support (use separate config path) -echo " Starting vLLM with LoRA support..." -VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_lora.json \ -CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \ - --model $MODEL \ - --port 9003 \ - --gpu-memory-utilization 0.35 \ - --enable-lora \ - --max-lora-rank 32 \ - --enforce-eager \ - > $LOGDIR/vllm_lora.log 2>&1 & -LORA_VLLM_PID=$! -echo " ✓ vLLM started (PID: $LORA_VLLM_PID, port 9003)" -wait_for_vllm 9003 "lora" || { echo "Failed to start lora vLLM"; exit 1; } - -# Start environment server for LoRA -echo " Starting environment server..." -python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name "$MODEL" \ - --env.use_wandb=False \ - --env.rollout_server_url "http://localhost:8003" \ - --openai.model_name "$MODEL" \ - --openai.base_url "http://localhost:9003/v1" \ - --openai.server_type vllm \ - --slurm false \ - > $LOGDIR/env_lora.log 2>&1 & -LORA_ENV_PID=$! -echo " ✓ Environment server started (PID: $LORA_ENV_PID)" -sleep 5 - -# Start LoRA trainer -echo " Starting trainer..." -CUDA_VISIBLE_DEVICES=5 python -m example_trainer.grpo \ - --model-name $MODEL \ - --weight-bridge-mode lora_only \ - --vllm-port 9003 \ - --atropos-url http://localhost:8003 \ - --lora-r 16 \ - --lora-alpha 32 \ - --training-steps $TRAINING_STEPS \ - --batch-size $BATCH_SIZE \ - --save-path $LOGDIR/checkpoints_lora \ - --benchmark \ - > $LOGDIR/trainer_lora.log 2>&1 & -LORA_TRAINER_PID=$! -echo " ✓ Trainer started (PID: $LORA_TRAINER_PID)" - -# ============================================================================== -# Save PIDs and Monitor -# ============================================================================== -cat > $LOGDIR/pids.txt << EOF -LEGACY_TRAINER_PID=$LEGACY_TRAINER_PID -LEGACY_ENV_PID=$LEGACY_ENV_PID -LEGACY_API_PID=$LEGACY_API_PID -SHARED_TRAINER_PID=$SHARED_TRAINER_PID -SHARED_VLLM_PID=$SHARED_VLLM_PID -SHARED_ENV_PID=$SHARED_ENV_PID -SHARED_API_PID=$SHARED_API_PID -LORA_TRAINER_PID=$LORA_TRAINER_PID -LORA_VLLM_PID=$LORA_VLLM_PID -LORA_ENV_PID=$LORA_ENV_PID -LORA_API_PID=$LORA_API_PID -EOF - -echo "" -echo "==============================================" -echo "All components started!" 
-echo "==============================================" -echo "" -echo "📂 Log directory: $LOGDIR" -echo "" -echo "📊 Monitor progress:" -echo " tail -f $LOGDIR/trainer_legacy.log" -echo " tail -f $LOGDIR/trainer_shared.log" -echo " tail -f $LOGDIR/trainer_lora.log" -echo "" -echo "🔍 Or watch all at once:" -echo " tail -f $LOGDIR/trainer_*.log" -echo "" -echo "📋 Check API servers:" -echo " curl http://localhost:8001/info" -echo " curl http://localhost:8002/info" -echo " curl http://localhost:8003/info" -echo "" -echo "📝 Process IDs saved to: $LOGDIR/pids.txt" -echo "" -echo "Waiting for all trainers to complete..." -echo "(This may take a while depending on training steps)" -echo "" - -# Wait for trainers to complete -wait $LEGACY_TRAINER_PID 2>/dev/null && echo " ✓ Legacy trainer completed" || echo " ✗ Legacy trainer failed" -wait $SHARED_TRAINER_PID 2>/dev/null && echo " ✓ Shared vLLM trainer completed" || echo " ✗ Shared vLLM trainer failed" -wait $LORA_TRAINER_PID 2>/dev/null && echo " ✓ LoRA trainer completed" || echo " ✗ LoRA trainer failed" - -# ============================================================================== -# Print Results -# ============================================================================== -echo "" -echo "==============================================" -echo "COMPARISON RESULTS" -echo "==============================================" -echo "" - -echo "📊 LEGACY MODE:" -echo "─────────────────────────────────────────────" -grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_legacy.log 2>/dev/null || echo " (check $LOGDIR/trainer_legacy.log)" -echo "" - -echo "📊 SHARED VLLM MODE:" -echo "─────────────────────────────────────────────" -grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_shared.log 2>/dev/null || echo " (check $LOGDIR/trainer_shared.log)" -echo "" - -echo "📊 LORA MODE:" -echo "─────────────────────────────────────────────" -grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_lora.log 2>/dev/null || echo " (check $LOGDIR/trainer_lora.log)" -echo "" - -echo "==============================================" -echo "📁 ALL OUTPUT SAVED TO: $LOGDIR" -echo "==============================================" -echo "" -echo "📋 LOG FILES:" -echo " Trainers (main output):" -echo " $LOGDIR/trainer_legacy.log" -echo " $LOGDIR/trainer_shared.log" -echo " $LOGDIR/trainer_lora.log" -echo "" -echo " vLLM servers:" -echo " $LOGDIR/vllm_shared.log" -echo " $LOGDIR/vllm_lora.log" -echo "" -echo " Environment servers:" -echo " $LOGDIR/env_legacy.log" -echo " $LOGDIR/env_shared.log" -echo " $LOGDIR/env_lora.log" -echo "" -echo " API servers:" -echo " $LOGDIR/api_legacy.log" -echo " $LOGDIR/api_shared.log" -echo " $LOGDIR/api_lora.log" -echo "" -echo "💾 CHECKPOINTS:" -echo " Legacy: $LOGDIR/checkpoints_legacy/final_model/" -echo " Shared vLLM: $LOGDIR/checkpoints_shared/final_model/" -echo " LoRA: $LOGDIR/checkpoints_lora/final_adapter/" -echo "" -echo "🔧 OTHER:" -echo " Process IDs: $LOGDIR/pids.txt" -echo " IPC Config (shared): $LOGDIR/vllm_bridge_config_shared.json" -echo " IPC Config (lora): $LOGDIR/vllm_bridge_config_lora.json" -echo "==============================================" -echo "" -echo "To re-run or inspect later:" -echo " export LOGDIR=$LOGDIR" -echo " tail -f \$LOGDIR/trainer_*.log" -echo "" -echo "Done!" 
diff --git a/example_trainer/scripts/test_lora_mode.sh b/example_trainer/scripts/test_lora_mode.sh index 5f9ff9e7..078e78e0 100644 --- a/example_trainer/scripts/test_lora_mode.sh +++ b/example_trainer/scripts/test_lora_mode.sh @@ -87,14 +87,13 @@ sleep 10 echo "" echo "[3/4] Baseline test (before training)..." -curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ +curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ -H "Content-Type: application/json" \ -d '{ - "model": "'"$MODEL"'", - "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "prompt": "<|im_start|>user\nWhat is 123 + 456?<|im_end|>\n<|im_start|>assistant\n", "max_tokens": 100, "temperature": 0.1 - }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/baseline_response.txt" + }' | jq '.text[0]' | tee "${LOG_DIR}/baseline_response.txt" echo "" echo "[4/4] Starting LoRA trainer..." @@ -130,14 +129,13 @@ if [ -d "$LOG_DIR/checkpoints" ]; then echo "" echo "Response after training:" - curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ + curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ -H "Content-Type: application/json" \ -d '{ - "model": "'"$MODEL"'", - "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "prompt": "<|im_start|>user\nWhat is 123 + 456?<|im_end|>\n<|im_start|>assistant\n", "max_tokens": 100, "temperature": 0.1 - }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/trained_response.txt" + }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" fi fi diff --git a/example_trainer/scripts/test_single_copy_mode.sh b/example_trainer/scripts/test_single_copy_mode.sh index 28efc6de..1022ea72 100644 --- a/example_trainer/scripts/test_single_copy_mode.sh +++ b/example_trainer/scripts/test_single_copy_mode.sh @@ -97,14 +97,13 @@ sleep 10 echo "" echo "[3/4] Baseline test (before training)..." -curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ +curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ -H "Content-Type: application/json" \ -d '{ - "model": "'"$MODEL"'", - "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "prompt": "<|im_start|>user\nWhat is 123 + 456?<|im_end|>\n<|im_start|>assistant\n", "max_tokens": 100, "temperature": 0.1 - }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/baseline_response.txt" + }' | jq '.text[0]' | tee "${LOG_DIR}/baseline_response.txt" echo "" echo "[4/4] Starting Single-Copy trainer..." 
@@ -137,12 +136,11 @@ echo "============================================================" # Post-training test echo "" echo "Post-training test (weights are already updated in vLLM):" -curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ +curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ -H "Content-Type: application/json" \ -d '{ - "model": "'"$MODEL"'", - "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "prompt": "<|im_start|>user\nWhat is 123 + 456?<|im_end|>\n<|im_start|>assistant\n", "max_tokens": 100, "temperature": 0.1 - }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/trained_response.txt" + }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index fcc082b9..173465b6 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -114,8 +114,20 @@ from vllm.usage.usage_lib import UsageContext # noqa: E402 from vllm.utils import random_uuid # noqa: E402 from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402 +# Handle vLLM version differences - FlexibleArgumentParser was removed/renamed +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + # Use standard argparse for newer vLLM versions + from argparse import ArgumentParser as FlexibleArgumentParser -from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402 +# set_ulimit might not exist in all vLLM versions +try: + from vllm.utils import set_ulimit +except ImportError: + def set_ulimit() -> None: + """No-op fallback for set_ulimit.""" + pass from vllm.outputs import RequestOutput # noqa: F401, E402 from vllm.version import __version__ as VLLM_VERSION # noqa: E402 @@ -312,220 +324,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: return JSONResponse(ret) -# ============================================================================= -# OpenAI-Compatible Chat Completions Endpoint -# ============================================================================= - - -@app.post("/v1/chat/completions") -async def openai_chat_completions(request: Request) -> Response: - """ - OpenAI-compatible chat completions endpoint. - - This is a thin wrapper around our /generate endpoint that formats - the request/response to match OpenAI's chat completions API. - - Used by atroposlib/GSM8k environment for rollout generation. 
- """ - if engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") - - import time as time_module - - request_dict = await request.json() - - # Extract parameters - model = request_dict.get("model", "") - messages = request_dict.get("messages", []) - max_tokens = request_dict.get("max_tokens", 256) - temperature = request_dict.get("temperature", 1.0) - top_p = request_dict.get("top_p", 1.0) - n = request_dict.get("n", 1) - stop = request_dict.get("stop", None) - presence_penalty = request_dict.get("presence_penalty", 0.0) - frequency_penalty = request_dict.get("frequency_penalty", 0.0) - - # Convert messages to prompt using tokenizer's chat template - try: - prompt = engine.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - except Exception: - # Fallback: simple concatenation if no chat template - prompt = "" - for msg in messages: - role = msg.get("role", "user") - content = msg.get("content", "") - prompt += f"{role}: {content}\n" - prompt += "assistant: " - - # Build sampling params (reusing our existing logic) - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - n=n, - stop=stop, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - - request_id = random_uuid() - - # Get active LoRA adapter if any - lora_request = _get_lora_request() - - final_output = None - async for request_output in engine.generate( - prompt, sampling_params, request_id, lora_request=lora_request - ): - final_output = request_output - - if final_output is None: - raise HTTPException(status_code=500, detail="Generation failed") - - # Build choices in OpenAI chat format - choices = [] - for idx, output in enumerate(final_output.outputs): - choices.append( - { - "index": idx, - "message": { - "role": "assistant", - "content": output.text, - }, - "finish_reason": output.finish_reason or "stop", - } - ) - - # Build response - prompt_tokens = len(final_output.prompt_token_ids) - completion_tokens = sum(len(o.token_ids) for o in final_output.outputs) - - response = { - "id": f"chatcmpl-{random_uuid()}", - "object": "chat.completion", - "created": int(time_module.time()), - "model": model, - "choices": choices, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - } - - return JSONResponse(response) - - -@app.post("/v1/completions") -async def openai_completions(request: Request) -> Response: - """ - OpenAI-compatible text completions endpoint. - - This is the non-chat version of completions (raw text in, text out). 
- """ - if engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") - - import time as time_module - - request_dict = await request.json() - - # Extract parameters - model = request_dict.get("model", "") - prompt = request_dict.get("prompt", "") - max_tokens = request_dict.get("max_tokens", 256) - temperature = request_dict.get("temperature", 1.0) - top_p = request_dict.get("top_p", 1.0) - n = request_dict.get("n", 1) - stop = request_dict.get("stop", None) - presence_penalty = request_dict.get("presence_penalty", 0.0) - frequency_penalty = request_dict.get("frequency_penalty", 0.0) - logprobs_requested = request_dict.get("logprobs", None) - - # Handle single prompt or list of prompts - prompts = [prompt] if isinstance(prompt, str) else prompt - - # Build sampling params - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - n=n, - stop=stop, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logprobs=logprobs_requested, - ) - - # Generate for each prompt - all_choices = [] - total_prompt_tokens = 0 - total_completion_tokens = 0 - - # Get active LoRA adapter if any - lora_request = _get_lora_request() - - for prompt_text in prompts: - request_id = random_uuid() - - final_output = None - async for request_output in engine.generate( - prompt_text, sampling_params, request_id, lora_request=lora_request - ): - final_output = request_output - - if final_output is None: - raise HTTPException(status_code=500, detail="Generation failed") - - # Count tokens - total_prompt_tokens += len(final_output.prompt_token_ids) - - # Build choices - for output in final_output.outputs: - total_completion_tokens += len(output.token_ids) - - choice = { - "text": output.text, - "index": len(all_choices), - "logprobs": None, - "finish_reason": output.finish_reason or "stop", - } - - # Add logprobs if requested - if logprobs_requested is not None and output.logprobs: - choice["logprobs"] = { - "tokens": [ - engine.tokenizer.decode([tid]) for tid in output.token_ids - ], - "token_logprobs": [ - list(lp.values())[0].logprob if lp else None - for lp in output.logprobs - ], - "top_logprobs": None, - "text_offset": [], - } - - all_choices.append(choice) - - # Build response in OpenAI format - response = { - "id": f"cmpl-{random_uuid()}", - "object": "text_completion", - "created": int(time_module.time()), - "model": model, - "choices": all_choices, - "usage": { - "prompt_tokens": total_prompt_tokens, - "completion_tokens": total_completion_tokens, - "total_tokens": total_prompt_tokens + total_completion_tokens, - }, - } - - return JSONResponse(response) - - # ============================================================================= # Bridge Endpoints (Weight Synchronization) # ============================================================================= @@ -911,12 +709,16 @@ async def run_server( # Log available endpoints logger.info("=" * 60) + logger.info("Streamlined vLLM Server - Training-Focused API") logger.info("Available endpoints:") - logger.info(" POST /generate - Generate completions") - logger.info(" GET /bridge/info - Bridge status") - logger.info(" POST /bridge/pause - Pause generation") - logger.info(" POST /bridge/resume - Resume generation") - logger.info(" GET /lora/status - LoRA adapter status") + logger.info(" POST /generate - Generate with logprobs (primary endpoint)") + logger.info(" GET /health - Health check") + logger.info(" GET /bridge/info - Bridge status") + logger.info(" POST 
/bridge/pause - Pause generation") + logger.info(" POST /bridge/resume - Resume generation") + logger.info(" GET /lora/status - LoRA adapter status") + logger.info(" POST /lora/load - Load LoRA adapter") + logger.info(" POST /lora/unload - Unload LoRA adapter") logger.info("=" * 60) shutdown_task = await serve_http( diff --git a/example_trainer/vllm_manager.py b/example_trainer/vllm_manager.py index bdc8c2a7..3713660a 100644 --- a/example_trainer/vllm_manager.py +++ b/example_trainer/vllm_manager.py @@ -118,9 +118,9 @@ def launch_vllm_server( Launch a vLLM server process using our custom vllm_api_server.py. Uses the custom server instead of standard vLLM because: - - Standard vLLM only has /v1/completions (OpenAI-compatible) - - Our custom server has /generate endpoint needed by VLLMServer class - - This allows proper tokens_and_logprobs_completion support + - Streamlined API: Only /generate endpoint (provides logprobs) + - Weight bridge support: /bridge/* endpoints for shared memory mode + - LoRA hot-swap: /lora/* endpoints for adapter loading/unloading Args: config: Training configuration