mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
major refactor 2
This commit is contained in:
parent
6833d4d820
commit
3a1229afaf
4 changed files with 175 additions and 33 deletions
|
|
@ -37,7 +37,8 @@
|
|||
# ├── trainer_legacy.log # GRPO trainer log (MAIN OUTPUT)
|
||||
# ├── trainer_shared.log # GRPO trainer log (MAIN OUTPUT)
|
||||
# ├── trainer_lora.log # GRPO trainer log (MAIN OUTPUT)
|
||||
# ├── vllm_bridge_config.json # CUDA IPC config (shared mode)
|
||||
# ├── vllm_bridge_config_shared.json # CUDA IPC config (shared mode)
|
||||
# ├── vllm_bridge_config_lora.json # CUDA IPC config (lora mode)
|
||||
# ├── pids.txt # Process IDs for cleanup
|
||||
# ├── checkpoints_legacy/ # Model checkpoints
|
||||
# │ ├── step_3/
|
||||
|
|
@ -63,6 +64,53 @@ export LOGDIR="${LOGDIR:-./comparison_$(date +%Y%m%d_%H%M%S)}"
|
|||
|
||||
mkdir -p $LOGDIR
|
||||
|
||||
# ==============================================================================
|
||||
# Helper function: Wait for vLLM to be ready
|
||||
# ==============================================================================
|
||||
# Poll the vLLM /health endpoint until it answers or the retry budget runs out.
#   $1 = port to probe
#   $2 = label used in log messages
#   $3 = max polls (optional; default 60, i.e. ~5 minutes at 5s per poll)
# Returns 0 once healthy, 1 on timeout.
wait_for_vllm() {
    local port=$1
    local name=$2
    local max_attempts=${3:-60}
    local try

    echo " Waiting for vLLM ($name) on port $port..."
    for (( try = 1; try <= max_attempts; try++ )); do
        if curl -s "http://localhost:$port/health" > /dev/null 2>&1; then
            echo " ✓ vLLM ($name) is ready after ~$((try * 5))s"
            return 0
        fi
        echo " Attempt $try/$max_attempts - vLLM not ready yet..."
        sleep 5
    done

    echo " ✗ vLLM ($name) failed to start after $((max_attempts * 5))s"
    return 1
}
|
||||
|
||||
# ==============================================================================
|
||||
# Helper function: Wait for API server to be ready
|
||||
# ==============================================================================
|
||||
# Poll the run-api /info endpoint until it answers or retries are exhausted.
#   $1 = port to probe
#   $2 = label used in log messages
#   $3 = max polls (optional; default 20, at 2s per poll)
# Returns 0 once reachable, 1 on timeout.
wait_for_api() {
    local port=$1
    local name=$2
    local max_attempts=${3:-20}
    local try

    echo " Waiting for API ($name) on port $port..."
    for (( try = 1; try <= max_attempts; try++ )); do
        if curl -s "http://localhost:$port/info" > /dev/null 2>&1; then
            echo " ✓ API ($name) is ready"
            return 0
        fi
        sleep 2
    done

    echo " ✗ API ($name) failed to start"
    return 1
}
|
||||
|
||||
echo "=============================================="
|
||||
echo "GRPO Training Mode Comparison"
|
||||
echo "=============================================="
|
||||
|
|
@ -109,22 +157,10 @@ echo " Starting run-api server..."
|
|||
run-api --port 8001 > $LOGDIR/api_legacy.log 2>&1 &
|
||||
LEGACY_API_PID=$!
|
||||
echo " ✓ run-api started (PID: $LEGACY_API_PID, port 8001)"
|
||||
sleep 3
|
||||
wait_for_api 8001 "legacy" || { echo "Failed to start legacy API"; exit 1; }
|
||||
|
||||
# Start environment server for Legacy
|
||||
echo " Starting environment server..."
|
||||
python environments/gsm8k_server.py serve \
|
||||
--slurm.num_gpus 0 \
|
||||
--env.tokenizer_name $MODEL \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--server.port 8001 \
|
||||
> $LOGDIR/env_legacy.log 2>&1 &
|
||||
LEGACY_ENV_PID=$!
|
||||
echo " ✓ Environment server started (PID: $LEGACY_ENV_PID)"
|
||||
sleep 5
|
||||
|
||||
# Start Legacy trainer (it will manage its own vLLM)
|
||||
echo " Starting trainer..."
|
||||
# In legacy mode: trainer launches vLLM internally, so start trainer FIRST
|
||||
echo " Starting trainer (will launch internal vLLM on port 9001)..."
|
||||
CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
|
||||
--model-name $MODEL \
|
||||
--weight-bridge-mode none \
|
||||
|
|
@ -138,6 +174,21 @@ CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
|
|||
LEGACY_TRAINER_PID=$!
|
||||
echo " ✓ Trainer started (PID: $LEGACY_TRAINER_PID)"
|
||||
|
||||
# Wait for trainer's internal vLLM to be ready
|
||||
wait_for_vllm 9001 "legacy (internal)" || { echo "Legacy vLLM failed to start"; exit 1; }
|
||||
|
||||
# NOW start environment server (after vLLM is ready)
|
||||
echo " Starting environment server..."
|
||||
python environments/gsm8k_server.py serve \
|
||||
--slurm.num_gpus 0 \
|
||||
--env.tokenizer_name $MODEL \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--server.port 8001 \
|
||||
> $LOGDIR/env_legacy.log 2>&1 &
|
||||
LEGACY_ENV_PID=$!
|
||||
echo " ✓ Environment server started (PID: $LEGACY_ENV_PID)"
|
||||
sleep 3
|
||||
|
||||
# ==============================================================================
|
||||
# MODE 2: SHARED_VLLM (GPUs 2-3, API 8002, vLLM 9002)
|
||||
# ==============================================================================
|
||||
|
|
@ -151,11 +202,11 @@ echo " Starting run-api server..."
|
|||
run-api --port 8002 > $LOGDIR/api_shared.log 2>&1 &
|
||||
SHARED_API_PID=$!
|
||||
echo " ✓ run-api started (PID: $SHARED_API_PID, port 8002)"
|
||||
sleep 3
|
||||
wait_for_api 8002 "shared" || { echo "Failed to start shared API"; exit 1; }
|
||||
|
||||
# Start vLLM with shared weights
|
||||
# Start vLLM with shared weights (use separate config path)
|
||||
echo " Starting vLLM with shared weights..."
|
||||
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
|
||||
VLLM_ENABLE_SHARED_WEIGHTS=1 VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_shared.json \
|
||||
CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
|
||||
--model $MODEL \
|
||||
--port 9002 \
|
||||
|
|
@ -163,8 +214,7 @@ CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
|
|||
> $LOGDIR/vllm_shared.log 2>&1 &
|
||||
SHARED_VLLM_PID=$!
|
||||
echo " ✓ vLLM started (PID: $SHARED_VLLM_PID, port 9002)"
|
||||
echo " Waiting for vLLM to initialize (30s)..."
|
||||
sleep 30
|
||||
wait_for_vllm 9002 "shared" || { echo "Failed to start shared vLLM"; exit 1; }
|
||||
|
||||
# Start environment server for Shared
|
||||
echo " Starting environment server..."
|
||||
|
|
@ -184,7 +234,7 @@ CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \
|
|||
--model-name $MODEL \
|
||||
--weight-bridge-mode shared_vllm \
|
||||
--vllm-port 9002 \
|
||||
--vllm-config-path $LOGDIR/vllm_bridge_config.json \
|
||||
--vllm-config-path $LOGDIR/vllm_bridge_config_shared.json \
|
||||
--atropos-url http://localhost:8002 \
|
||||
--training-steps $TRAINING_STEPS \
|
||||
--batch-size $BATCH_SIZE \
|
||||
|
|
@ -207,10 +257,11 @@ echo " Starting run-api server..."
|
|||
run-api --port 8003 > $LOGDIR/api_lora.log 2>&1 &
|
||||
LORA_API_PID=$!
|
||||
echo " ✓ run-api started (PID: $LORA_API_PID, port 8003)"
|
||||
sleep 3
|
||||
wait_for_api 8003 "lora" || { echo "Failed to start lora API"; exit 1; }
|
||||
|
||||
# Start vLLM with LoRA support
|
||||
# Start vLLM with LoRA support (use separate config path)
|
||||
echo " Starting vLLM with LoRA support..."
|
||||
VLLM_BRIDGE_CONFIG_PATH=$LOGDIR/vllm_bridge_config_lora.json \
|
||||
CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \
|
||||
--model $MODEL \
|
||||
--port 9003 \
|
||||
|
|
@ -221,8 +272,7 @@ CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \
|
|||
> $LOGDIR/vllm_lora.log 2>&1 &
|
||||
LORA_VLLM_PID=$!
|
||||
echo " ✓ vLLM started (PID: $LORA_VLLM_PID, port 9003)"
|
||||
echo " Waiting for vLLM to initialize (30s)..."
|
||||
sleep 30
|
||||
wait_for_vllm 9003 "lora" || { echo "Failed to start lora vLLM"; exit 1; }
|
||||
|
||||
# Start environment server for LoRA
|
||||
echo " Starting environment server..."
|
||||
|
|
@ -356,7 +406,8 @@ echo " LoRA: $LOGDIR/checkpoints_lora/final_adapter/"
|
|||
echo ""
|
||||
echo "🔧 OTHER:"
|
||||
echo " Process IDs: $LOGDIR/pids.txt"
|
||||
echo " IPC Config: $LOGDIR/vllm_bridge_config.json"
|
||||
echo " IPC Config (shared): $LOGDIR/vllm_bridge_config_shared.json"
|
||||
echo " IPC Config (lora): $LOGDIR/vllm_bridge_config_lora.json"
|
||||
echo "=============================================="
|
||||
echo ""
|
||||
echo "To re-run or inspect later:"
|
||||
|
|
|
|||
|
|
@ -858,9 +858,15 @@ async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastA
|
|||
|
||||
def _export_state_dict_info(args: Namespace) -> None:
|
||||
"""Export basic model info to JSON for trainer (backup if patches don't run)."""
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
# Allow explicit config path via env var, otherwise use LOGDIR
|
||||
config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH")
|
||||
if config_path:
|
||||
json_path = Path(config_path)
|
||||
json_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
|
||||
# Only write basic info if the file doesn't exist or is empty
|
||||
# The patched runner will write complete info with param_mappings
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ for legacy mode training.
|
|||
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
|
@ -20,6 +22,73 @@ from .config import TrainingConfig
|
|||
_vllm_process: Optional[subprocess.Popen] = None
|
||||
|
||||
|
||||
def is_port_in_use(port: int) -> bool:
    """Return True when something is already listening on *port* (localhost).

    Uses a non-raising TCP probe: ``connect_ex`` yields 0 only when the
    connection is accepted, i.e. a server is bound to the port.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        outcome = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return outcome == 0
|
||||
|
||||
|
||||
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
    """
    Kill any process using the specified port.

    Resolution ladder: ``lsof -t -i :<port>`` to find owning PIDs, SIGTERM
    each, wait up to *timeout* seconds for the port to free, then escalate
    to SIGKILL. If ``lsof`` is not installed, fall back to ``fuser -k``.

    Args:
        port: TCP port to free.
        timeout: Seconds to wait after SIGTERM before escalating to SIGKILL.

    Returns True if no process was running or if it was successfully killed.
    """
    if not is_port_in_use(port):
        return True

    print(f" Port {port} is in use, attempting to kill existing process...")

    try:
        # Try to find and kill the process using lsof (Linux/Mac)
        result = subprocess.run(
            ["lsof", "-t", "-i", f":{port}"],
            capture_output=True,
            text=True,
            timeout=5
        )
        if result.stdout.strip():
            pids = result.stdout.strip().split('\n')
            for pid in pids:
                try:
                    os.kill(int(pid), signal.SIGTERM)
                    print(f" Sent SIGTERM to PID {pid}")
                except (ProcessLookupError, ValueError):
                    # PID already exited, or lsof emitted a non-numeric line.
                    pass

            # Wait for port to be free
            start = time.time()
            while time.time() - start < timeout:
                if not is_port_in_use(port):
                    print(f" Port {port} is now free")
                    return True
                time.sleep(0.5)

            # Force kill if still running
            for pid in pids:
                try:
                    os.kill(int(pid), signal.SIGKILL)
                    print(f" Sent SIGKILL to PID {pid}")
                except (ProcessLookupError, ValueError):
                    pass

            # Brief grace period so the OS can release the socket before
            # the final re-check.
            time.sleep(1)
            return not is_port_in_use(port)
        # NOTE(review): if lsof prints no PIDs while the port is busy (e.g.
        # the owner belongs to another user), control falls through to the
        # warning below and returns False.
    except FileNotFoundError:
        # lsof not available, try fuser (Linux)
        try:
            subprocess.run(["fuser", "-k", f"{port}/tcp"], timeout=5)
            time.sleep(1)
            return not is_port_in_use(port)
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass
    except subprocess.TimeoutExpired:
        # lsof itself hung; give up rather than block startup.
        pass

    print(f" WARNING: Could not kill process on port {port}")
    return False
||||
|
||||
|
||||
def cleanup_vllm():
|
||||
"""Cleanup function to terminate vLLM on exit."""
|
||||
global _vllm_process
|
||||
|
|
@ -62,6 +131,16 @@ def launch_vllm_server(
|
|||
"""
|
||||
global _vllm_process
|
||||
|
||||
# Check if port is in use and try to kill existing process
|
||||
if is_port_in_use(config.vllm_port):
|
||||
print(f" WARNING: Port {config.vllm_port} is already in use!")
|
||||
if not kill_process_on_port(config.vllm_port):
|
||||
print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.")
|
||||
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
|
||||
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
|
||||
return None
|
||||
print(f" Successfully freed port {config.vllm_port}")
|
||||
|
||||
# Use our custom vllm_api_server.py
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
|
||||
|
|
|
|||
|
|
@ -295,9 +295,15 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
print(f"[vLLM Patch] Note: model.share_memory() not available: {e}")
|
||||
|
||||
# Export parameter info to JSON for trainer
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
# Allow explicit config path via env var, otherwise use LOGDIR
|
||||
config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH")
|
||||
if config_path:
|
||||
json_path = Path(config_path)
|
||||
json_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
|
||||
param_mappings = {}
|
||||
param_names = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue