Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent ccdd5a1ca6
commit 60fb6cae11
11 changed files with 221 additions and 136 deletions
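The hunks below are mechanical fixes applied by the pre-commit hooks (Black-style formatting): statements longer than the 88-character default are wrapped, exploded call and list arguments get trailing commas, and single-quoted strings are normalized to double quotes. As a minimal sketch, the snippet below reproduces one such rewrite with Black's Python API; the `black` import and the sample statement are illustrative assumptions, not part of this commit.

# Illustrative sketch only (assumes the `black` package is installed); it is
# not part of the commit, it just reproduces the kind of rewrite shown below.
import black

SRC = (
    "vllm_log_path = os.path.join("
    'config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")\n'
)

# Black re-wraps statements that exceed its 88-character line length and
# normalizes quote style; the printed result has the same wrapped shape as the
# corresponding change in one of the hunks below.
print(black.format_str(SRC, mode=black.Mode(line_length=88)))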
@@ -735,19 +735,23 @@ def train_lora_restart(config: TrainingConfig):
         # Periodic adapter save + vLLM restart
         sync_time = 0
         should_sync = (step + 1) % config.vllm_restart_interval == 0
-        if should_sync and (step + 1) < config.training_steps:  # Don't restart on last step
+        if (
+            should_sync and (step + 1) < config.training_steps
+        ):  # Don't restart on last step
             sync_start = time.time()

             # Save new adapter
-            current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
+            current_adapter_path = save_lora_checkpoint(
+                model, config.save_path, step + 1
+            )

             # Restart vLLM with new adapter
             print(" [RESTART] Restarting vLLM with new adapter...")
             _terminate_vllm(vllm_proc, config.vllm_port)
             vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
             if vllm_proc is None:
                 raise RuntimeError("Failed to restart vLLM")

             sync_time = time.time() - sync_start
             benchmark_stats["sync_times"].append(sync_time)
             benchmark_stats["restart_times"].append(sync_time)
@@ -798,45 +802,53 @@ def train_lora_restart(config: TrainingConfig):
 _vllm_restart_counter = 0


-def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
+def _launch_vllm_with_lora(
+    config: TrainingConfig, adapter_path: str
+) -> Optional[subprocess.Popen]:
     """
     Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).

     Unlike lora_only mode, this does NOT use --enforce-eager, so we get
     ~108 TPS instead of ~13 TPS (8x faster).
     """
     global _vllm_restart_counter
     from .vllm_manager import kill_process_on_port, wait_for_vllm_ready

     # Kill any existing process on the port
     print(f" Cleaning up port {config.vllm_port}...")
     kill_process_on_port(config.vllm_port)

     # Clear CUDA cache before starting new vLLM
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

     # Wait for port and GPU memory to be fully released
     time.sleep(5)

     # Find the vllm_api_server.py script
     script_dir = os.path.dirname(os.path.abspath(__file__))
     server_script = os.path.join(script_dir, "vllm_api_server.py")

     # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
     cmd = [
-        sys.executable, server_script,
-        "--model", config.model_name,
-        "--port", str(config.vllm_port),
-        "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
-        "--max-model-len", str(config.max_model_len),
+        sys.executable,
+        server_script,
+        "--model",
+        config.model_name,
+        "--port",
+        str(config.vllm_port),
+        "--gpu-memory-utilization",
+        str(config.vllm_gpu_memory_utilization),
+        "--max-model-len",
+        str(config.max_model_len),
         "--enable-lora",
-        "--max-lora-rank", str(max(config.lora_r * 2, 32)),
+        "--max-lora-rank",
+        str(max(config.lora_r * 2, 32)),
         # Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
         # Without --enforce-eager, vLLM can use more optimizations.
     ]

     # Set environment for GPU selection
     env = os.environ.copy()
     if config.vllm_gpu is not None:
@@ -844,32 +856,39 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
         print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
     else:
         print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")

     print(f" Launching: {' '.join(cmd)}")
     print(f" Adapter: {adapter_path}")

     # Log vLLM output to file for debugging (unique file per restart)
-    vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")
+    vllm_log_path = os.path.join(
+        config.save_path, f"vllm_restart_{_vllm_restart_counter}.log"
+    )
     _vllm_restart_counter += 1
     print(f" vLLM log: {vllm_log_path}")

     try:
         vllm_log_file = open(vllm_log_path, "w")
         # Start in new session so we can kill entire process group later
         proc = subprocess.Popen(
-            cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT,
-            start_new_session=True  # Creates new process group for easy cleanup
+            cmd,
+            env=env,
+            stdout=vllm_log_file,
+            stderr=subprocess.STDOUT,
+            start_new_session=True,  # Creates new process group for easy cleanup
         )
         print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
-        print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...")
+        print(
+            " NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..."
+        )

         # Wait for server to be ready (longer timeout for CUDA graph compilation)
         if not wait_for_vllm_ready(config.vllm_port, timeout=300):
             print(" ERROR: vLLM failed to start after 300s")
             print(f" Check log: {vllm_log_path}")
             # Print last 30 lines of the log
             try:
-                with open(vllm_log_path, 'r') as f:
+                with open(vllm_log_path, "r") as f:
                     lines = f.readlines()
                     print(" Last 30 lines of vLLM log:")
                     for line in lines[-30:]:
@@ -878,7 +897,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
                     print(f" Could not read log: {e}")
             proc.terminate()
             return None

         # Load the LoRA adapter
         print(" Loading LoRA adapter...")
         try:
@@ -890,13 +909,15 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
             if resp.status_code == 200:
                 print(" ✓ Adapter loaded successfully")
             else:
-                print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
+                print(
+                    f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
+                )
         except Exception as e:
             print(f" WARNING: Could not load adapter: {e}")
             # Continue anyway - base model inference still works

         return proc

     except Exception as e:
         print(f" ERROR: {e}")
         return None
@@ -906,12 +927,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
     """Terminate a vLLM process and release GPU resources."""
     import signal
     import subprocess as sp

     print(f" Terminating vLLM on port {port}...")

     # Get current GPU device
     gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]

     # Phase 1: Kill the process group if we have a handle (kills all children too)
     main_pid = None
     if proc is not None:
@@ -927,12 +948,13 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
             proc.wait(timeout=5)
         except Exception as e:
             print(f" Warning: {e}")

     # Phase 2: Kill by port (catches anything still running)
     from .vllm_manager import kill_process_on_port
+
     kill_process_on_port(port)
     time.sleep(2)

     # Phase 3: Aggressively kill ALL vLLM-related processes
     print(" Killing all vLLM-related processes...")
     kill_commands = [
@@ -948,19 +970,22 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
             sp.run(cmd, shell=True, capture_output=True, timeout=5)
         except Exception:
             pass

     # Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option)
     print(f" Checking for zombie GPU processes on GPU {gpu_id}...")
     try:
         result = sp.run(
             f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
-            shell=True, capture_output=True, text=True, timeout=10
+            shell=True,
+            capture_output=True,
+            text=True,
+            timeout=10,
         )
         if result.stdout.strip():
             print(f" Found GPU processes:\n{result.stdout}")
-            for line in result.stdout.strip().split('\n'):
+            for line in result.stdout.strip().split("\n"):
                 if line.strip():
-                    parts = line.split(',')
+                    parts = line.split(",")
                     if len(parts) >= 1:
                         pid = parts[0].strip()
                         # Don't kill the current Python process (trainer)
@@ -972,7 +997,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
                                pass
     except Exception as e:
         print(f" Warning: nvidia-smi check failed: {e}")

     # Phase 5: Wait for GPU memory release - CRITICAL
     # The CUDA driver needs time to actually free memory after process death
     print(" Waiting for GPU memory release...")
@@ -982,24 +1007,26 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
         torch.cuda.empty_cache()
         free_mem = torch.cuda.mem_get_info()[0] / 1e9
         total_mem = torch.cuda.mem_get_info()[1] / 1e9
-        print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
+        print(
+            f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
+        )
         # If we have enough memory (>50% free), break early
         if free_mem > total_mem * 0.5:
             print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
             break

     # Final cleanup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
         free_mem = torch.cuda.mem_get_info()[0] / 1e9
         total_mem = torch.cuda.mem_get_info()[1] / 1e9
-        print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
+        print(
+            f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
+        )

         if free_mem < total_mem * 0.3:
             print(" WARNING: Low GPU memory! May fail to restart vLLM.")
             print(" Consider reducing --vllm-gpu-memory-utilization")

     print(" vLLM terminated")
-
-