mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
452 lines
17 KiB
Python
452 lines
17 KiB
Python
"""
|
|
Patched GPU Model Runner - Enables CUDA IPC for single-copy training.
|
|
|
|
This patches vLLM's GPUModelRunner to:
|
|
1. Call share_memory_() on model weights after loading
|
|
2. Export CUDA IPC handles to vllm_bridge_config.json
|
|
|
|
The key insight is that CUDA IPC handles allow the trainer process to
|
|
attach to the EXACT SAME GPU memory that vLLM uses. This means:
|
|
- ONE copy of model weights in GPU memory
|
|
- Trainer's optimizer.step() updates vLLM's weights directly
|
|
- No synchronization needed - vLLM immediately sees new weights
|
|
|
|
CRITICAL: This module must be imported and apply_patches() called BEFORE
|
|
any vLLM imports. The patches MUST happen before vLLM caches module references.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import sys
|
|
|
|
# Flag to track if patches have been applied
|
|
_PATCHES_APPLIED = False
|
|
_PATCHED_RUNNER_CLASS = None
|
|
|
|
|
|
def _patch_lora_triton_for_blackwell() -> bool:
|
|
"""
|
|
Patch vLLM's LoRA Triton kernels to disable GDC (Grid Dependency Control).
|
|
|
|
GDC is a Blackwell-specific feature that causes Triton compilation to fail
|
|
on B200 GPUs. This patches the kernel_utils.py to disable GDC.
|
|
|
|
Returns True if patch was applied successfully.
|
|
"""
|
|
try:
|
|
import vllm
|
|
|
|
vllm_path = vllm.__path__[0]
|
|
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
|
|
|
|
# Check if file exists
|
|
if not os.path.exists(kernel_utils_path):
|
|
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
|
|
return False
|
|
|
|
with open(kernel_utils_path, "r") as f:
|
|
content = f.read()
|
|
|
|
# Check if already patched
|
|
if "PATCHED FOR B200" in content:
|
|
print("[vLLM Patch] LoRA GDC already patched for B200")
|
|
return True
|
|
|
|
modified = False
|
|
|
|
# Patch USE_GDC = True -> False
|
|
if "USE_GDC = True" in content:
|
|
content = content.replace(
|
|
"USE_GDC = True",
|
|
"USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure",
|
|
)
|
|
modified = True
|
|
|
|
# Patch USE_GDC: tl.constexpr = True -> False
|
|
if "USE_GDC: tl.constexpr = True" in content:
|
|
content = content.replace(
|
|
"USE_GDC: tl.constexpr = True",
|
|
"USE_GDC: tl.constexpr = False # PATCHED FOR B200",
|
|
)
|
|
modified = True
|
|
|
|
# Patch the gdc_wait call itself
|
|
if "tl.extra.cuda.gdc_wait()" in content:
|
|
content = content.replace(
|
|
"tl.extra.cuda.gdc_wait()",
|
|
"pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled",
|
|
)
|
|
modified = True
|
|
|
|
if modified:
|
|
with open(kernel_utils_path, "w") as f:
|
|
f.write(content)
|
|
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
|
|
|
|
# Clear Triton cache to force recompilation
|
|
triton_cache = os.path.expanduser("~/.triton/cache")
|
|
if os.path.exists(triton_cache):
|
|
try:
|
|
shutil.rmtree(triton_cache)
|
|
print("[vLLM Patch] ✓ Cleared Triton cache")
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Warning: Could not clear Triton cache: {e}")
|
|
|
|
return True
|
|
else:
|
|
print("[vLLM Patch] No GDC patterns found to patch")
|
|
return False
|
|
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Warning: Could not patch LoRA GDC: {e}")
|
|
return False
|
|
|
|
|
|
def apply_patches() -> bool:
|
|
"""
|
|
Apply patches to vLLM's GPUModelRunner in ALL locations.
|
|
|
|
This must be called BEFORE importing vLLM's engine classes.
|
|
Safe to call multiple times (idempotent).
|
|
|
|
Also patches LoRA Triton kernels to disable GDC for B200 compatibility.
|
|
|
|
Returns True if patches were applied successfully.
|
|
|
|
Usage:
|
|
# CRITICAL: Import and call BEFORE any vLLM imports!
|
|
import os
|
|
os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
|
|
|
|
from example_trainer.vllm_patching import apply_patches
|
|
apply_patches()
|
|
|
|
# Now import vLLM
|
|
from vllm import AsyncLLM # Uses patched runner
|
|
"""
|
|
global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS
|
|
|
|
if _PATCHES_APPLIED:
|
|
return True
|
|
|
|
# First, patch LoRA Triton for B200 compatibility
|
|
_patch_lora_triton_for_blackwell()
|
|
|
|
try:
|
|
# Import the source module and get original class
|
|
import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module
|
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner
|
|
|
|
# Create the patched class
|
|
PatchedRunner = _create_patched_runner(OriginalRunner)
|
|
_PATCHED_RUNNER_CLASS = PatchedRunner
|
|
|
|
# =================================================================
|
|
# PATCH 1: Replace in source module
|
|
# =================================================================
|
|
gpu_model_runner_module.GPUModelRunner = PatchedRunner
|
|
print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner")
|
|
|
|
# =================================================================
|
|
# PATCH 2: Replace in gpu_worker module (main usage location)
|
|
# =================================================================
|
|
try:
|
|
import vllm.v1.worker.gpu_worker as gpu_worker_module
|
|
|
|
gpu_worker_module.GPUModelRunner = PatchedRunner
|
|
print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner")
|
|
except ImportError:
|
|
pass
|
|
|
|
# =================================================================
|
|
# PATCH 3: Update sys.modules entry for source module
|
|
# =================================================================
|
|
# This ensures new imports get the patched version
|
|
if "vllm.v1.worker.gpu_model_runner" in sys.modules:
|
|
sys.modules["vllm.v1.worker.gpu_model_runner"].GPUModelRunner = (
|
|
PatchedRunner
|
|
)
|
|
|
|
# =================================================================
|
|
# PATCH 4: Patch GPUWorker if already imported
|
|
# =================================================================
|
|
try:
|
|
if "vllm.v1.worker.gpu_worker" in sys.modules:
|
|
worker_module = sys.modules["vllm.v1.worker.gpu_worker"]
|
|
if hasattr(worker_module, "GPUWorker"):
|
|
# Update any class-level references
|
|
worker_module.GPUModelRunner = PatchedRunner
|
|
except Exception:
|
|
pass
|
|
|
|
_PATCHES_APPLIED = True
|
|
print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")
|
|
return True
|
|
|
|
except ImportError as e:
|
|
print(f"[vLLM Patch] Warning: Could not apply patches: {e}")
|
|
print("[vLLM Patch] This may be due to vLLM version incompatibility")
|
|
print("[vLLM Patch] Shared memory updates will not be available")
|
|
return False
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Error applying patches: {e}")
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
|
|
def _create_patched_runner(BaseRunner: type) -> type:
|
|
"""
|
|
Create a patched GPUModelRunner class.
|
|
|
|
Returns a new class that inherits from the original and adds
|
|
CUDA IPC export functionality for single-copy training.
|
|
"""
|
|
|
|
class PatchedGPUModelRunner(BaseRunner):
|
|
"""
|
|
Patched GPUModelRunner that enables CUDA IPC for single-copy training.
|
|
|
|
After loading the model, this:
|
|
1. Calls share_memory_() on all parameters
|
|
2. Exports CUDA IPC handles to vllm_bridge_config.json
|
|
|
|
The trainer reads these IPC handles and attaches to the SAME
|
|
GPU memory, so optimizer.step() updates weights that vLLM
|
|
immediately sees for inference.
|
|
"""
|
|
|
|
_shared_memory_setup_done = False
|
|
|
|
def load_model(self, *args, **kwargs) -> None:
|
|
"""Load model and set up shared memory + update daemon."""
|
|
print("[vLLM Patch] PatchedGPUModelRunner.load_model() called!")
|
|
|
|
# Call original load_model
|
|
super().load_model(*args, **kwargs)
|
|
|
|
print("[vLLM Patch] Model loaded, checking shared weights setup...")
|
|
|
|
# Check if shared memory updates are enabled
|
|
enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
|
|
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
|
|
|
|
print(
|
|
f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}"
|
|
)
|
|
|
|
if not enable_shared and num_inference_nodes < 0:
|
|
print(
|
|
"[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)"
|
|
)
|
|
return
|
|
|
|
if self._shared_memory_setup_done:
|
|
print("[vLLM Patch] Shared memory already set up, skipping")
|
|
return
|
|
|
|
print("[vLLM Patch] Setting up shared memory weight updates...", flush=True)
|
|
|
|
try:
|
|
self._setup_shared_memory()
|
|
PatchedGPUModelRunner._shared_memory_setup_done = True
|
|
print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
|
|
print(
|
|
"[vLLM Patch] ✓ IPC handles exported - trainer can now attach!",
|
|
flush=True,
|
|
)
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
return
|
|
|
|
def _setup_shared_memory(self) -> None:
|
|
"""Move model tensors to shared memory and export param info."""
|
|
import json
|
|
from pathlib import Path
|
|
|
|
print("[vLLM Patch] _setup_shared_memory() starting...")
|
|
|
|
# Get state dict
|
|
state_dict = self.model.state_dict()
|
|
print(f"[vLLM Patch] Model has {len(state_dict)} parameters")
|
|
|
|
# Make entire model shareable via share_memory_() on each tensor
|
|
shared_count = 0
|
|
for key, val in state_dict.items():
|
|
try:
|
|
if val.is_cuda:
|
|
val.share_memory_()
|
|
shared_count += 1
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Warning: Could not share {key}: {e}")
|
|
|
|
print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors")
|
|
|
|
# Also try calling share_memory() on the model itself
|
|
try:
|
|
self.model.share_memory()
|
|
print("[vLLM Patch] Called model.share_memory()")
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Note: model.share_memory() not available: {e}")
|
|
|
|
# Export parameter info to JSON for trainer
|
|
# Allow explicit config path via env var, otherwise use LOGDIR
|
|
config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH")
|
|
if config_path:
|
|
json_path = Path(config_path)
|
|
json_path.parent.mkdir(parents=True, exist_ok=True)
|
|
else:
|
|
log_dir = os.environ.get("LOGDIR", ".")
|
|
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
|
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
|
|
|
param_mappings = {}
|
|
param_names = []
|
|
ipc_handles = {}
|
|
|
|
for name, tensor in state_dict.items():
|
|
param_mappings[name] = {
|
|
"vllm_name": name,
|
|
"shape": list(tensor.shape),
|
|
"dtype": str(tensor.dtype),
|
|
"device": str(tensor.device),
|
|
}
|
|
param_names.append(name)
|
|
|
|
# Export CUDA IPC handles for true single-copy mode
|
|
if tensor.is_cuda:
|
|
try:
|
|
import base64
|
|
|
|
storage = tensor.untyped_storage()
|
|
share_data = storage._share_cuda_()
|
|
|
|
# share_data is a tuple of 8 items - we need ALL of them:
|
|
# [0] = device index (int)
|
|
# [1] = cudaIpcMemHandle_t (bytes)
|
|
# [2] = storage size (int)
|
|
# [3] = storage offset in original (int)
|
|
# [4] = ref counter handle (bytes - filename)
|
|
# [5] = ref counter offset (int)
|
|
# [6] = event handle (bytes)
|
|
# [7] = event sync required (bool)
|
|
|
|
ipc_handles[name] = {
|
|
"device_index": share_data[0],
|
|
"ipc_handle_b64": base64.b64encode(share_data[1]).decode(
|
|
"ascii"
|
|
),
|
|
"storage_size": share_data[2],
|
|
"storage_offset_orig": share_data[3],
|
|
"ref_counter_handle_b64": base64.b64encode(
|
|
share_data[4]
|
|
).decode("ascii"),
|
|
"ref_counter_offset": share_data[5],
|
|
"event_handle_b64": base64.b64encode(share_data[6]).decode(
|
|
"ascii"
|
|
),
|
|
"event_sync_required": share_data[7],
|
|
# Tensor metadata for reconstruction
|
|
"tensor_storage_offset": tensor.storage_offset(),
|
|
"shape": list(tensor.shape),
|
|
"stride": list(tensor.stride()),
|
|
"dtype": str(tensor.dtype),
|
|
}
|
|
except Exception as e:
|
|
print(
|
|
f"[vLLM Patch] Could not get IPC handle for {name}: {e}",
|
|
flush=True,
|
|
)
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
|
|
print(
|
|
f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode",
|
|
flush=True,
|
|
)
|
|
|
|
# Get model info
|
|
model_name = "unknown"
|
|
tp_degree = 1
|
|
try:
|
|
model_name = str(self.model_config.model)
|
|
tp_degree = self.parallel_config.tensor_parallel_size
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] Warning: Could not get model config: {e}")
|
|
|
|
import base64
|
|
|
|
# Convert bytes to base64 for JSON serialization
|
|
def serialize_ipc_handles(handles):
|
|
result = {}
|
|
for k, v in handles.items():
|
|
if isinstance(v, bytes):
|
|
result[k] = {"_bytes_b64_": base64.b64encode(v).decode("ascii")}
|
|
elif isinstance(v, dict):
|
|
result[k] = serialize_ipc_handles(v)
|
|
else:
|
|
result[k] = v
|
|
return result
|
|
|
|
serialized_ipc_handles = (
|
|
serialize_ipc_handles(ipc_handles) if ipc_handles else {}
|
|
)
|
|
|
|
info = {
|
|
"model": model_name,
|
|
"tp_degree": tp_degree,
|
|
"dp_shard_degree": 1,
|
|
"param_mappings": param_mappings,
|
|
"param_names": sorted(param_names),
|
|
"ipc_handles": serialized_ipc_handles,
|
|
"shared_weights_enabled": True,
|
|
"num_params": len(param_names),
|
|
"single_copy_enabled": len(ipc_handles) > 0,
|
|
}
|
|
|
|
try:
|
|
with open(json_path, "w") as f:
|
|
json.dump(info, f, indent=2)
|
|
print(
|
|
f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}"
|
|
)
|
|
except Exception as e:
|
|
print(f"[vLLM Patch] ERROR: Failed to export params: {e}")
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
|
|
# Set proper class name
|
|
PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
|
|
PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
|
|
|
|
return PatchedGPUModelRunner
|
|
|
|
|
|
def get_patched_runner() -> type | None:
|
|
"""Get the patched runner class if patches have been applied."""
|
|
return _PATCHED_RUNNER_CLASS
|
|
|
|
|
|
def is_patched() -> bool:
|
|
"""Check if patches have been applied."""
|
|
return _PATCHES_APPLIED
|
|
|
|
|
|
# Placeholder class for type checking
|
|
class PatchedGPUModelRunner:
|
|
"""
|
|
Placeholder class for type checking.
|
|
|
|
The actual patched class is created dynamically by _create_patched_runner()
|
|
to properly inherit from vLLM's GPUModelRunner.
|
|
"""
|
|
|
|
pass
|