atropos/example_trainer/vllm_manager.py

"""
vLLM process management for GRPO trainer.
Handles launching, monitoring, and terminating vLLM server processes
for legacy mode training.
"""
import atexit
import os
import signal
import socket
import subprocess
import time
from typing import Optional
import requests
from .config import TrainingConfig
# Global variable to keep track of the vLLM process
_vllm_process: Optional[subprocess.Popen] = None


def is_port_in_use(port: int) -> bool:
    """Check if a port is already in use."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
    """
    Kill any process using the specified port.

    Returns True if no process was running or if it was successfully killed.
    """
    if not is_port_in_use(port):
        return True

    print(f" Port {port} is in use, attempting to kill existing process...")
    try:
        # Try to find the owning PIDs using lsof (Linux/Mac)
        result = subprocess.run(
            ["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
        )
        if result.stdout.strip():
            pids = result.stdout.strip().split("\n")
            print(f" Killing {len(pids)} processes on port {port}...")
            for pid in pids:
                try:
                    os.kill(int(pid), signal.SIGTERM)
                except (ProcessLookupError, ValueError):
                    pass

            # Wait for the port to be freed
            start = time.time()
            while time.time() - start < timeout:
                if not is_port_in_use(port):
                    print(f" Port {port} is now free")
                    return True
                time.sleep(0.5)

            # Force kill anything still running
            killed_count = 0
            for pid in pids:
                try:
                    os.kill(int(pid), signal.SIGKILL)
                    killed_count += 1
                except (ProcessLookupError, ValueError):
                    pass
            if killed_count > 0:
                print(f" Force killed {killed_count} stubborn processes")
                time.sleep(1)
            return not is_port_in_use(port)
    except FileNotFoundError:
        # lsof not available; fall back to fuser (Linux)
        try:
            subprocess.run(["fuser", "-k", f"{port}/tcp"], timeout=5)
            time.sleep(1)
            return not is_port_in_use(port)
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass
    except subprocess.TimeoutExpired:
        pass

    print(f" WARNING: Could not kill process on port {port}")
    return False
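

# Example usage (a minimal sketch, not part of the module's API): free the
# configured port before relaunching. `cfg` stands in for any TrainingConfig
# instance; launch_vllm_server() below performs the same check internally.
#
#     if is_port_in_use(cfg.vllm_port):
#         if not kill_process_on_port(cfg.vllm_port, timeout=10.0):
#             raise RuntimeError(f"port {cfg.vllm_port} could not be freed")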


def cleanup_vllm():
    """Cleanup function to terminate vLLM on exit."""
    global _vllm_process
    if _vllm_process:
        print("\nTerminating vLLM process...")
        _vllm_process.terminate()
        try:
            _vllm_process.wait(timeout=5)
            print("vLLM process terminated.")
        except subprocess.TimeoutExpired:
            print("vLLM process did not terminate gracefully, killing.")
            _vllm_process.kill()
            _vllm_process.wait()
            print("vLLM process killed.")
        _vllm_process = None


# Register cleanup on module load
atexit.register(cleanup_vllm)
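

# Note: atexit handlers do not run when the trainer is killed by a signal.
# A minimal sketch (an assumption about the deployment, not something this
# module does) that routes SIGTERM through the same cleanup path:
#
#     import sys
#
#     def _on_sigterm(signum, frame):
#         cleanup_vllm()
#         sys.exit(128 + signum)
#
#     signal.signal(signal.SIGTERM, _on_sigterm)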


def launch_vllm_server(
    config: TrainingConfig,
    model_path: str,
) -> Optional[subprocess.Popen]:
    """
    Launch a vLLM server process using our custom vllm_api_server.py.

    Uses the custom server instead of standard vLLM because:
    - Streamlined API: only the /generate endpoint (provides logprobs)
    - Weight bridge support: /bridge/* endpoints for shared-memory mode
    - LoRA hot-swap: /lora/* endpoints for adapter loading/unloading

    Args:
        config: Training configuration
        model_path: Path to model checkpoint

    Returns:
        Popen process object, or None if launch failed
    """
    global _vllm_process

    # Check if the port is in use and try to kill any existing process
    if is_port_in_use(config.vllm_port):
        print(f" WARNING: Port {config.vllm_port} is already in use!")
        if not kill_process_on_port(config.vllm_port):
            print(
                f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
            )
            print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
            print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
            return None
        print(f" Successfully freed port {config.vllm_port}")

    # Use our custom vllm_api_server.py, which lives next to this module
    script_dir = os.path.dirname(os.path.abspath(__file__))
    custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
    vllm_command = [
        "python",
        custom_server_path,
        "--model",
        model_path,
        "--port",
        str(config.vllm_port),
        "--gpu-memory-utilization",
        str(config.vllm_gpu_memory_utilization),
    ]
    # Add --served-model-name when serving from a checkpoint path so the
    # server still advertises the original model name
    if model_path != config.model_name:
        vllm_command.extend(["--served-model-name", config.model_name])

    print(f" Launching vLLM: {' '.join(vllm_command)}")
    try:
        proc = subprocess.Popen(vllm_command)
        print(f" vLLM launched with PID: {proc.pid}")

        # Give the server two seconds to crash on an immediate startup error;
        # TimeoutExpired means it is still running, which is the good case.
        try:
            proc.communicate(timeout=2)
            # Any exit within 2s (even code 0) means the server is not serving
            if proc.returncode is not None:
                print(f" WARNING: vLLM failed to start (code: {proc.returncode})")
                return None
        except subprocess.TimeoutExpired:
            print(" vLLM process started (check logs for details)")

        _vllm_process = proc
        return proc
    except FileNotFoundError:
        print(" ERROR: vLLM not found. Is it installed?")
        return None
    except Exception as e:
        print(f" ERROR launching vLLM: {e}")
        return None
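

# A hedged sketch of hitting the custom server's /generate endpoint once it is
# up. The payload field names ("prompt", "max_tokens") are assumptions about
# vllm_api_server.py, not a documented contract:
#
#     resp = requests.post(
#         f"http://localhost:{cfg.vllm_port}/generate",
#         json={"prompt": "...", "max_tokens": 64},
#         timeout=60,
#     )
#     resp.raise_for_status()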


def terminate_vllm_process() -> None:
    """Terminate the running vLLM process, if any."""
    global _vllm_process
    if _vllm_process is None:
        return
    print(" Terminating vLLM process...")
    _vllm_process.terminate()
    try:
        _vllm_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        print(" vLLM did not terminate gracefully, killing...")
        _vllm_process.kill()
        _vllm_process.wait()
    _vllm_process = None


def check_vllm_process_health() -> None:
    """Check whether the vLLM process has terminated unexpectedly."""
    global _vllm_process
    if _vllm_process is not None and _vllm_process.poll() is not None:
        print(
            f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
        )
        _vllm_process = None


def get_vllm_process() -> Optional[subprocess.Popen]:
    """Get the current vLLM process."""
    return _vllm_process


def set_vllm_process(proc: Optional[subprocess.Popen]) -> None:
    """Set the vLLM process (for external management)."""
    global _vllm_process
    _vllm_process = proc


def check_vllm_health(port: int) -> bool:
    """
    Check if the vLLM server is healthy and responding.

    Args:
        port: Port the vLLM server is running on

    Returns:
        True if the server is healthy
    """
    try:
        response = requests.get(f"http://localhost:{port}/health", timeout=5)
        return response.status_code == 200
    except Exception:
        return False


def wait_for_vllm_ready(port: int, timeout: float = 120.0) -> bool:
    """
    Wait for the vLLM server to become ready.

    Args:
        port: Port the vLLM server is running on
        timeout: Maximum time to wait in seconds

    Returns:
        True if the server is ready, False on timeout
    """
    print(f" Waiting for vLLM to be ready (port {port})...")
    start_time = time.time()
    while time.time() - start_time < timeout:
        if check_vllm_health(port):
            print(" vLLM is ready!")
            return True
        time.sleep(2)
    print(f" WARNING: vLLM not ready after {timeout}s")
    return False
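

# Putting the pieces together: a minimal launch sequence (a sketch; `cfg` is
# any TrainingConfig and the checkpoint path is hypothetical):
#
#     proc = launch_vllm_server(cfg, model_path="/ckpts/step_100")
#     if proc is None or not wait_for_vllm_ready(cfg.vllm_port, timeout=300.0):
#         terminate_vllm_process()
#         raise RuntimeError("vLLM server failed to come up")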


def hotswap_lora_adapter(
    adapter_name: str,
    adapter_path: str,
    port: int,
) -> bool:
    """
    Hot-swap a LoRA adapter on a running vLLM server.

    Uses the vLLM /v1/load_lora_adapter endpoint to load a new adapter
    without restarting the server.

    Args:
        adapter_name: Name to identify the adapter
        adapter_path: Path to the adapter checkpoint
        port: vLLM server port

    Returns:
        True if the hot-swap succeeded
    """
    try:
        # Use vLLM's native LoRA loading endpoint
        response = requests.post(
            f"http://localhost:{port}/v1/load_lora_adapter",
            json={
                "lora_name": adapter_name,
                "lora_path": adapter_path,
            },
            timeout=30,
        )
        if response.status_code == 200:
            print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
            return True
        else:
            print(
                f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
            )
            return False
    except requests.exceptions.ConnectionError:
        print(f" [LORA] ✗ Cannot connect to vLLM at port {port}")
        return False
    except Exception as e:
        print(f" [LORA] ✗ Error during hot-swap: {e}")
        return False
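

# Typical use after saving a new adapter checkpoint (a sketch; the adapter name
# and path are hypothetical, and the server must have been launched with LoRA
# support enabled):
#
#     if hotswap_lora_adapter("grpo_step_200", "/ckpts/lora/step_200", cfg.vllm_port):
#         pass  # subsequent generation requests can target the new adapter name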