mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
hot swap adapter
This commit is contained in:
parent
c86b36844b
commit
347f9ea363
3 changed files with 121 additions and 86 deletions
|
|
@ -18,6 +18,7 @@ any vLLM imports. The patches MUST happen before vLLM caches module references.
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
# Flag to track if patches have been applied
|
||||
|
|
@ -25,12 +26,91 @@ _PATCHES_APPLIED = False
|
|||
_PATCHED_RUNNER_CLASS = None
|
||||
|
||||
|
||||
def _patch_lora_triton_for_blackwell() -> bool:
|
||||
"""
|
||||
Patch vLLM's LoRA Triton kernels to disable GDC (Grid Dependency Control).
|
||||
|
||||
GDC is a Blackwell-specific feature that causes Triton compilation to fail
|
||||
on B200 GPUs. This patches the kernel_utils.py to disable GDC.
|
||||
|
||||
Returns True if patch was applied successfully.
|
||||
"""
|
||||
try:
|
||||
import vllm
|
||||
vllm_path = vllm.__path__[0]
|
||||
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(kernel_utils_path):
|
||||
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
|
||||
return False
|
||||
|
||||
with open(kernel_utils_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Check if already patched
|
||||
if 'PATCHED FOR B200' in content:
|
||||
print("[vLLM Patch] LoRA GDC already patched for B200")
|
||||
return True
|
||||
|
||||
modified = False
|
||||
|
||||
# Patch USE_GDC = True -> False
|
||||
if 'USE_GDC = True' in content:
|
||||
content = content.replace(
|
||||
'USE_GDC = True',
|
||||
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure'
|
||||
)
|
||||
modified = True
|
||||
|
||||
# Patch USE_GDC: tl.constexpr = True -> False
|
||||
if 'USE_GDC: tl.constexpr = True' in content:
|
||||
content = content.replace(
|
||||
'USE_GDC: tl.constexpr = True',
|
||||
'USE_GDC: tl.constexpr = False # PATCHED FOR B200'
|
||||
)
|
||||
modified = True
|
||||
|
||||
# Patch the gdc_wait call itself
|
||||
if 'tl.extra.cuda.gdc_wait()' in content:
|
||||
content = content.replace(
|
||||
'tl.extra.cuda.gdc_wait()',
|
||||
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled'
|
||||
)
|
||||
modified = True
|
||||
|
||||
if modified:
|
||||
with open(kernel_utils_path, 'w') as f:
|
||||
f.write(content)
|
||||
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
|
||||
|
||||
# Clear Triton cache to force recompilation
|
||||
triton_cache = os.path.expanduser("~/.triton/cache")
|
||||
if os.path.exists(triton_cache):
|
||||
try:
|
||||
shutil.rmtree(triton_cache)
|
||||
print("[vLLM Patch] ✓ Cleared Triton cache")
|
||||
except Exception as e:
|
||||
print(f"[vLLM Patch] Warning: Could not clear Triton cache: {e}")
|
||||
|
||||
return True
|
||||
else:
|
||||
print("[vLLM Patch] No GDC patterns found to patch")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"[vLLM Patch] Warning: Could not patch LoRA GDC: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def apply_patches() -> bool:
|
||||
"""
|
||||
Apply patches to vLLM's GPUModelRunner in ALL locations.
|
||||
|
||||
This must be called BEFORE importing vLLM's engine classes.
|
||||
Safe to call multiple times (idempotent).
|
||||
|
||||
Also patches LoRA Triton kernels to disable GDC for B200 compatibility.
|
||||
|
||||
Returns True if patches were applied successfully.
|
||||
|
||||
|
|
@ -49,6 +129,9 @@ def apply_patches() -> bool:
|
|||
|
||||
if _PATCHES_APPLIED:
|
||||
return True
|
||||
|
||||
# First, patch LoRA Triton for B200 compatibility
|
||||
_patch_lora_triton_for_blackwell()
|
||||
|
||||
try:
|
||||
# Import the source module and get original class
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue