hot swap adapter

This commit is contained in:
Jai Suphavadeeprasit 2026-01-20 22:11:53 -05:00
parent c86b36844b
commit 347f9ea363
3 changed files with 121 additions and 86 deletions

View file

@ -18,6 +18,7 @@ any vLLM imports. The patches MUST happen before vLLM caches module references.
from __future__ import annotations
import os
import shutil
import sys
# Flag to track if patches have been applied
@ -25,12 +26,91 @@ _PATCHES_APPLIED = False
_PATCHED_RUNNER_CLASS = None
def _patch_lora_triton_for_blackwell() -> bool:
"""
Patch vLLM's LoRA Triton kernels to disable GDC (Grid Dependency Control).
GDC is a Blackwell-specific feature that causes Triton compilation to fail
on B200 GPUs. This patches the kernel_utils.py to disable GDC.
Returns True if patch was applied successfully.
"""
try:
import vllm
vllm_path = vllm.__path__[0]
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
# Check if file exists
if not os.path.exists(kernel_utils_path):
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
return False
with open(kernel_utils_path, 'r') as f:
content = f.read()
# Check if already patched
if 'PATCHED FOR B200' in content:
print("[vLLM Patch] LoRA GDC already patched for B200")
return True
modified = False
# Patch USE_GDC = True -> False
if 'USE_GDC = True' in content:
content = content.replace(
'USE_GDC = True',
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure'
)
modified = True
# Patch USE_GDC: tl.constexpr = True -> False
if 'USE_GDC: tl.constexpr = True' in content:
content = content.replace(
'USE_GDC: tl.constexpr = True',
'USE_GDC: tl.constexpr = False # PATCHED FOR B200'
)
modified = True
# Patch the gdc_wait call itself
if 'tl.extra.cuda.gdc_wait()' in content:
content = content.replace(
'tl.extra.cuda.gdc_wait()',
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled'
)
modified = True
if modified:
with open(kernel_utils_path, 'w') as f:
f.write(content)
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
# Clear Triton cache to force recompilation
triton_cache = os.path.expanduser("~/.triton/cache")
if os.path.exists(triton_cache):
try:
shutil.rmtree(triton_cache)
print("[vLLM Patch] ✓ Cleared Triton cache")
except Exception as e:
print(f"[vLLM Patch] Warning: Could not clear Triton cache: {e}")
return True
else:
print("[vLLM Patch] No GDC patterns found to patch")
return False
except Exception as e:
print(f"[vLLM Patch] Warning: Could not patch LoRA GDC: {e}")
return False
def apply_patches() -> bool:
"""
Apply patches to vLLM's GPUModelRunner in ALL locations.
This must be called BEFORE importing vLLM's engine classes.
Safe to call multiple times (idempotent).
Also patches LoRA Triton kernels to disable GDC for B200 compatibility.
Returns True if patches were applied successfully.
@ -49,6 +129,9 @@ def apply_patches() -> bool:
if _PATCHES_APPLIED:
return True
# First, patch LoRA Triton for B200 compatibility
_patch_lora_triton_for_blackwell()
try:
# Import the source module and get original class