mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
315 lines
9.2 KiB
Python
315 lines
9.2 KiB
Python
"""
|
|
vLLM process management for GRPO trainer.
|
|
|
|
Handles launching, monitoring, and terminating vLLM server processes
|
|
for legacy mode training.
|
|
"""
|
|
|
|
import atexit
|
|
import os
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
from .config import TrainingConfig
|
|
|
|
# Global variable to keep track of the vLLM process
|
|
_vllm_process: Optional[subprocess.Popen] = None
|
|
|
|
|
|
def is_port_in_use(port: int) -> bool:
|
|
"""Check if a port is already in use."""
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
return s.connect_ex(("localhost", port)) == 0
|
|
|
|
|
|
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
|
"""
|
|
Kill any process using the specified port.
|
|
|
|
Returns True if no process was running or if it was successfully killed.
|
|
"""
|
|
if not is_port_in_use(port):
|
|
return True
|
|
|
|
print(f" Port {port} is in use, attempting to kill existing process...")
|
|
|
|
try:
|
|
# Try to find and kill the process using lsof (Linux/Mac)
|
|
result = subprocess.run(
|
|
["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
|
|
)
|
|
if result.stdout.strip():
|
|
pids = result.stdout.strip().split("\n")
|
|
print(f" Killing {len(pids)} processes on port {port}...")
|
|
for pid in pids:
|
|
try:
|
|
os.kill(int(pid), signal.SIGTERM)
|
|
except (ProcessLookupError, ValueError):
|
|
pass
|
|
|
|
# Wait for port to be free
|
|
start = time.time()
|
|
while time.time() - start < timeout:
|
|
if not is_port_in_use(port):
|
|
print(f" Port {port} is now free")
|
|
return True
|
|
time.sleep(0.5)
|
|
|
|
# Force kill if still running
|
|
killed_count = 0
|
|
for pid in pids:
|
|
try:
|
|
os.kill(int(pid), signal.SIGKILL)
|
|
killed_count += 1
|
|
except (ProcessLookupError, ValueError):
|
|
pass
|
|
if killed_count > 0:
|
|
print(f" Force killed {killed_count} stubborn processes")
|
|
|
|
time.sleep(1)
|
|
return not is_port_in_use(port)
|
|
except FileNotFoundError:
|
|
# lsof not available, try fuser (Linux)
|
|
try:
|
|
subprocess.run(["fuser", "-k", f"{port}/tcp"], timeout=5)
|
|
time.sleep(1)
|
|
return not is_port_in_use(port)
|
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
|
pass
|
|
except subprocess.TimeoutExpired:
|
|
pass
|
|
|
|
print(f" WARNING: Could not kill process on port {port}")
|
|
return False
|
|
|
|
|
|
def cleanup_vllm():
|
|
"""Cleanup function to terminate vLLM on exit."""
|
|
global _vllm_process
|
|
if _vllm_process:
|
|
print("\nTerminating vLLM process...")
|
|
_vllm_process.terminate()
|
|
try:
|
|
_vllm_process.wait(timeout=5)
|
|
print("vLLM process terminated.")
|
|
except subprocess.TimeoutExpired:
|
|
print("vLLM process did not terminate gracefully, killing.")
|
|
_vllm_process.kill()
|
|
_vllm_process.wait()
|
|
print("vLLM process killed.")
|
|
_vllm_process = None
|
|
|
|
|
|
# Register cleanup on module load
|
|
atexit.register(cleanup_vllm)
|
|
|
|
|
|
def launch_vllm_server(
|
|
config: TrainingConfig,
|
|
model_path: str,
|
|
) -> Optional[subprocess.Popen]:
|
|
"""
|
|
Launch a vLLM server process using our custom vllm_api_server.py.
|
|
|
|
Uses the custom server instead of standard vLLM because:
|
|
- Streamlined API: Only /generate endpoint (provides logprobs)
|
|
- Weight bridge support: /bridge/* endpoints for shared memory mode
|
|
- LoRA hot-swap: /lora/* endpoints for adapter loading/unloading
|
|
|
|
Args:
|
|
config: Training configuration
|
|
model_path: Path to model checkpoint
|
|
|
|
Returns:
|
|
Popen process object, or None if launch failed
|
|
"""
|
|
global _vllm_process
|
|
|
|
# Check if port is in use and try to kill existing process
|
|
if is_port_in_use(config.vllm_port):
|
|
print(f" WARNING: Port {config.vllm_port} is already in use!")
|
|
if not kill_process_on_port(config.vllm_port):
|
|
print(
|
|
f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
|
|
)
|
|
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
|
|
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
|
|
return None
|
|
print(f" Successfully freed port {config.vllm_port}")
|
|
|
|
# Use our custom vllm_api_server.py
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
|
|
|
|
vllm_command = [
|
|
"python",
|
|
custom_server_path,
|
|
"--model",
|
|
model_path,
|
|
"--port",
|
|
str(config.vllm_port),
|
|
"--gpu-memory-utilization",
|
|
str(config.vllm_gpu_memory_utilization),
|
|
]
|
|
|
|
# Add served-model-name if using checkpoint path
|
|
if model_path != config.model_name:
|
|
vllm_command.extend(["--served-model-name", config.model_name])
|
|
|
|
print(f" Launching vLLM: {' '.join(vllm_command)}")
|
|
|
|
try:
|
|
proc = subprocess.Popen(vllm_command)
|
|
print(f" vLLM launched with PID: {proc.pid}")
|
|
|
|
# Check for immediate startup errors
|
|
try:
|
|
proc.communicate(timeout=2)
|
|
if proc.returncode is not None and proc.returncode != 0:
|
|
print(" WARNING: vLLM failed to start")
|
|
return None
|
|
except subprocess.TimeoutExpired:
|
|
print(" vLLM process started (check logs for details)")
|
|
|
|
_vllm_process = proc
|
|
return proc
|
|
|
|
except FileNotFoundError:
|
|
print(" ERROR: vLLM not found. Is it installed?")
|
|
return None
|
|
except Exception as e:
|
|
print(f" ERROR launching vLLM: {e}")
|
|
return None
|
|
|
|
|
|
def terminate_vllm_process() -> None:
|
|
"""Terminate the running vLLM process if any."""
|
|
global _vllm_process
|
|
|
|
if _vllm_process is None:
|
|
return
|
|
|
|
print(" Terminating vLLM process...")
|
|
_vllm_process.terminate()
|
|
try:
|
|
_vllm_process.wait(timeout=5)
|
|
except subprocess.TimeoutExpired:
|
|
print(" vLLM did not terminate gracefully, killing...")
|
|
_vllm_process.kill()
|
|
_vllm_process.wait()
|
|
_vllm_process = None
|
|
|
|
|
|
def check_vllm_process_health() -> None:
|
|
"""Check if vLLM process terminated unexpectedly."""
|
|
global _vllm_process
|
|
|
|
if _vllm_process is not None and _vllm_process.poll() is not None:
|
|
print(
|
|
f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
|
|
)
|
|
_vllm_process = None
|
|
|
|
|
|
def get_vllm_process() -> Optional[subprocess.Popen]:
|
|
"""Get the current vLLM process."""
|
|
return _vllm_process
|
|
|
|
|
|
def set_vllm_process(proc: Optional[subprocess.Popen]) -> None:
|
|
"""Set the vLLM process (for external management)."""
|
|
global _vllm_process
|
|
_vllm_process = proc
|
|
|
|
|
|
def check_vllm_health(port: int) -> bool:
|
|
"""
|
|
Check if vLLM server is healthy and responding.
|
|
|
|
Args:
|
|
port: Port the vLLM server is running on
|
|
|
|
Returns:
|
|
True if server is healthy
|
|
"""
|
|
try:
|
|
response = requests.get(f"http://localhost:{port}/health", timeout=5)
|
|
return response.status_code == 200
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def wait_for_vllm_ready(port: int, timeout: float = 120.0) -> bool:
|
|
"""
|
|
Wait for vLLM server to be ready.
|
|
|
|
Args:
|
|
port: Port the vLLM server is running on
|
|
timeout: Maximum time to wait in seconds
|
|
|
|
Returns:
|
|
True if server is ready, False if timeout
|
|
"""
|
|
print(f" Waiting for vLLM to be ready (port {port})...")
|
|
start_time = time.time()
|
|
|
|
while time.time() - start_time < timeout:
|
|
if check_vllm_health(port):
|
|
print(" vLLM is ready!")
|
|
return True
|
|
time.sleep(2)
|
|
|
|
print(f" WARNING: vLLM not ready after {timeout}s")
|
|
return False
|
|
|
|
|
|
def hotswap_lora_adapter(
|
|
adapter_name: str,
|
|
adapter_path: str,
|
|
port: int,
|
|
) -> bool:
|
|
"""
|
|
Hot-swap a LoRA adapter on a running vLLM server.
|
|
|
|
Uses the vLLM /v1/load_lora_adapter endpoint to load a new adapter
|
|
without restarting the server.
|
|
|
|
Args:
|
|
adapter_name: Name to identify the adapter
|
|
adapter_path: Path to the adapter checkpoint
|
|
port: vLLM server port
|
|
|
|
Returns:
|
|
True if hot-swap succeeded
|
|
"""
|
|
try:
|
|
# Use vLLM's native LoRA loading endpoint
|
|
response = requests.post(
|
|
f"http://localhost:{port}/v1/load_lora_adapter",
|
|
json={
|
|
"lora_name": adapter_name,
|
|
"lora_path": adapter_path,
|
|
},
|
|
timeout=30,
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
|
|
return True
|
|
else:
|
|
print(
|
|
f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
|
|
)
|
|
return False
|
|
|
|
except requests.exceptions.ConnectionError:
|
|
print(f" [LORA] ✗ Cannot connect to vLLM at port {port}")
|
|
return False
|
|
except Exception as e:
|
|
print(f" [LORA] ✗ Error during hot-swap: {e}")
|
|
return False
|