Jai Suphavadeeprasit 2026-02-12 13:56:07 -05:00
parent f29a3d04fa
commit 676593de73
2 changed files with 77 additions and 20 deletions


@@ -729,7 +729,7 @@ def train_lora_restart(config: TrainingConfig):
     # Check Atropos API
     if not check_atropos_api(url=config.atropos_url, timeout=30):
-        _terminate_vllm(vllm_proc)
+        _terminate_vllm(vllm_proc, config.vllm_port)
         raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}")

     register_trainer(config)
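
Note: check_atropos_api is defined elsewhere in this module and is not part of this diff. A minimal sketch of what such a reachability check might look like, assuming the API answers a plain HTTP GET (the route and helper body are illustrative, not the project's actual implementation):

    import requests

    def check_atropos_api(url: str, timeout: int = 30) -> bool:
        """Return True if the Atropos API answers an HTTP GET within `timeout` seconds."""
        try:
            # The exact route is an assumption; the real helper may probe a specific endpoint.
            return requests.get(url, timeout=timeout).status_code == 200
        except requests.RequestException:
            return False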
@@ -799,7 +799,7 @@ def train_lora_restart(config: TrainingConfig):
         # Restart vLLM with new adapter
         print(f" [RESTART] Restarting vLLM with new adapter...")
-        _terminate_vllm(vllm_proc)
+        _terminate_vllm(vllm_proc, config.vllm_port)
         vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
         if vllm_proc is None:
             raise RuntimeError("Failed to restart vLLM")
@@ -832,7 +832,7 @@ def train_lora_restart(config: TrainingConfig):
     benchmark_stats["sync_times"].append(final_sync_time)

     # Terminate vLLM
-    _terminate_vllm(vllm_proc)
+    _terminate_vllm(vllm_proc, config.vllm_port)

     finalize_training(
         use_wandb,
@@ -850,6 +850,10 @@ def train_lora_restart(config: TrainingConfig):
     print(f"Final adapter saved to {final_adapter_path}")


+# Global counter for vLLM restarts (for unique log files)
+_vllm_restart_counter = 0
+
+
 def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
     """
     Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
@@ -857,11 +861,20 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
     Unlike lora_only mode, this does NOT use --enforce-eager, so we get
     ~108 TPS instead of ~13 TPS (8x faster).
     """
+    global _vllm_restart_counter
     from .vllm_manager import kill_process_on_port, wait_for_vllm_ready

     # Kill any existing process on the port
     print(f" Cleaning up port {config.vllm_port}...")
     kill_process_on_port(config.vllm_port)
-    time.sleep(2)  # Wait for port to be fully released
+
+    # Clear CUDA cache before starting new vLLM
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    # Wait for port and GPU memory to be fully released
+    time.sleep(5)

     # Find the vllm_api_server.py script
     script_dir = os.path.dirname(os.path.abspath(__file__))
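
kill_process_on_port comes from .vllm_manager and is not shown in this diff. One common way to implement such a helper is with psutil, scanning TCP sockets for a listener on the port and killing the owning process; a sketch under that assumption (the project's real helper may work differently, e.g. via fuser):

    import psutil

    def kill_process_on_port(port: int) -> None:
        """SIGKILL any process that owns a TCP socket bound to `port`."""
        for conn in psutil.net_connections(kind="tcp"):
            if conn.laddr and conn.laddr.port == port and conn.pid:
                try:
                    psutil.Process(conn.pid).kill()
                except psutil.NoSuchProcess:
                    pass  # Process exited on its own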
@@ -891,8 +904,9 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
     print(f" Launching: {' '.join(cmd)}")
     print(f" Adapter: {adapter_path}")

-    # Log vLLM output to file for debugging
-    vllm_log_path = os.path.join(config.save_path, "vllm_internal.log")
+    # Log vLLM output to file for debugging (unique file per restart)
+    vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")
+    _vllm_restart_counter += 1
     print(f" vLLM log: {vllm_log_path}")

     try:
@@ -901,9 +915,19 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
     print(f" vLLM PID: {proc.pid}")
     print(f" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...")

-    # Wait for server to be ready
-    if not wait_for_vllm_ready(config.vllm_port, timeout=180):
-        print(" ERROR: vLLM failed to start")
+    # Wait for server to be ready (longer timeout for CUDA graph compilation)
+    if not wait_for_vllm_ready(config.vllm_port, timeout=300):
+        print(" ERROR: vLLM failed to start after 300s")
         print(f" Check log: {vllm_log_path}")
+        # Print last 30 lines of the log
+        try:
+            with open(vllm_log_path, 'r') as f:
+                lines = f.readlines()
+            print(" Last 30 lines of vLLM log:")
+            for line in lines[-30:]:
+                print(f"   {line.rstrip()}")
+        except Exception as e:
+            print(f" Could not read log: {e}")
         proc.terminate()
         return None
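
wait_for_vllm_ready is also imported from .vllm_manager and not shown here. A typical readiness probe polls the server's /health endpoint (which vLLM's OpenAI-compatible server exposes) until it answers or the timeout elapses; a sketch under that assumption:

    import time
    import requests

    def wait_for_vllm_ready(port: int, timeout: int = 300) -> bool:
        """Poll http://localhost:{port}/health until it returns 200 or `timeout` seconds pass."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                if requests.get(f"http://localhost:{port}/health", timeout=5).status_code == 200:
                    return True
            except requests.RequestException:
                pass  # Server not accepting connections yet
            time.sleep(2)
        return False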
@@ -930,18 +954,51 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
         return None


-def _terminate_vllm(proc: Optional[subprocess.Popen]) -> None:
-    """Terminate a vLLM process."""
-    if proc is None:
-        return
+def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
+    """Terminate a vLLM process and release GPU resources."""
+    import signal
+
+    print(f" Terminating vLLM...")
+
+    # Kill by port first (catches all child processes)
+    from .vllm_manager import kill_process_on_port
+    kill_process_on_port(port)
+
+    if proc is not None:
+        print(f" Killing main process (PID: {proc.pid})...")
+        try:
+            # Send SIGKILL immediately - vLLM doesn't gracefully shutdown well
+            proc.kill()
+            proc.wait(timeout=10)
+        except Exception as e:
+            print(f" Warning: {e}")
+
+    # Kill any remaining vLLM-related processes on this port
+    # Use pkill to catch any orphaned child processes
     try:
-        proc.terminate()
-        proc.wait(timeout=10)
-    except subprocess.TimeoutExpired:
-        proc.kill()
-        proc.wait()
+        import subprocess as sp
+        # Kill by port using fuser
+        sp.run(f"fuser -k {port}/tcp", shell=True, capture_output=True, timeout=5)
+        # Also kill any vllm processes that might be orphaned
+        sp.run("pkill -9 -f 'vllm.*EngineCore'", shell=True, capture_output=True, timeout=5)
+    except Exception:
+        pass
+
+    # Wait for GPU memory to be released by the OS
+    print(" Waiting for GPU memory release (10s)...")
+    time.sleep(10)
+
+    # Clear CUDA cache in this process (won't affect other processes but good hygiene)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    # Verify GPU memory is free
+    if torch.cuda.is_available():
+        free_mem = torch.cuda.mem_get_info()[0] / 1e9
+        total_mem = torch.cuda.mem_get_info()[1] / 1e9
+        print(f" ✓ GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free")
+
+    print(" ✓ vLLM terminated")