patching problem

Jai Suphavadeeprasit 2026-01-13 00:09:40 -05:00
parent 80d2608c4e
commit fad8e77be2
4 changed files with 234 additions and 76 deletions


@@ -8,16 +8,17 @@ This patches vLLM's GPUModelRunner to:
The key insight is that share_memory_() makes tensors accessible from
multiple processes. The daemon receives updates via NCCL and copies them
directly into the shared tensors, which vLLM reads for inference.
CRITICAL: This module must be imported and apply_patches() called BEFORE
any vLLM imports. The patches MUST happen before vLLM caches module references.
"""
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING
import torch
import torch.multiprocessing as mp
# Lazy imports to avoid circular dependencies
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
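The mechanism the module docstring describes can be seen in isolation with a minimal sketch (hypothetical names, a plain CPU tensor, and a fabricated update standing in for the NCCL receive):

import torch
import torch.multiprocessing as mp

def updater(shared_weight: torch.Tensor) -> None:
    # In the real daemon the new values arrive over NCCL; here they are fabricated.
    shared_weight.copy_(torch.ones_like(shared_weight))  # in-place copy into the shared storage

if __name__ == "__main__":
    weight = torch.zeros(4)
    weight.share_memory_()                  # storage becomes visible to child processes
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=updater, args=(weight,))
    p.start()
    p.join()
    print(weight)                           # the parent sees the update without reloading anything

The patched runner applies the same idea to the model's parameter tensors, with the spawned daemon doing the copy.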
@@ -25,42 +26,92 @@ if TYPE_CHECKING:
# Flag to track if patches have been applied
_PATCHES_APPLIED = False
_PATCHED_RUNNER_CLASS = None
def apply_patches() -> None:
def apply_patches() -> bool:
"""
Apply patches to vLLM's GPUModelRunner.
Apply patches to vLLM's GPUModelRunner in ALL locations.
This must be called BEFORE importing vLLM's engine classes.
Safe to call multiple times (idempotent).
Returns True if patches were applied successfully.
Usage:
# CRITICAL: Import and call BEFORE any vLLM imports!
import os
os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
from example_trainer.vllm_patching import apply_patches
apply_patches()
from vllm import AsyncLLM # Now uses patched runner
# Now import vLLM
from vllm import AsyncLLM # Uses patched runner
"""
global _PATCHES_APPLIED
global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS
if _PATCHES_APPLIED:
return
return True
try:
import vllm.v1.worker.gpu_worker
# Import the source module and get original class
import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module
from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner
# Create patched class
# Create the patched class
PatchedRunner = _create_patched_runner(OriginalRunner)
_PATCHED_RUNNER_CLASS = PatchedRunner
# Replace in vllm module
vllm.v1.worker.gpu_worker.GPUModelRunner = PatchedRunner
# =================================================================
# PATCH 1: Replace in source module
# =================================================================
gpu_model_runner_module.GPUModelRunner = PatchedRunner
print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner")
# =================================================================
# PATCH 2: Replace in gpu_worker module (main usage location)
# =================================================================
try:
import vllm.v1.worker.gpu_worker as gpu_worker_module
gpu_worker_module.GPUModelRunner = PatchedRunner
print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner")
except ImportError:
pass
# =================================================================
# PATCH 3: Update sys.modules entry for source module
# =================================================================
# This ensures new imports get the patched version
if 'vllm.v1.worker.gpu_model_runner' in sys.modules:
sys.modules['vllm.v1.worker.gpu_model_runner'].GPUModelRunner = PatchedRunner
# =================================================================
# PATCH 4: Patch GPUWorker if already imported
# =================================================================
try:
if 'vllm.v1.worker.gpu_worker' in sys.modules:
worker_module = sys.modules['vllm.v1.worker.gpu_worker']
if hasattr(worker_module, 'GPUWorker'):
# Update any class-level references
worker_module.GPUModelRunner = PatchedRunner
except Exception:
pass
_PATCHES_APPLIED = True
print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")
return True
except ImportError as e:
print(f"[vLLM Patch] Warning: Could not apply patches: {e}")
print("[vLLM Patch] This may be due to vLLM version incompatibility")
print("[vLLM Patch] Shared memory updates will not be available")
return False
except Exception as e:
print(f"[vLLM Patch] Error applying patches: {e}")
import traceback
traceback.print_exc()
return False
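The reason for rebinding the class in every location is that `from module import Name` copies the reference into the importing module at import time, so rebinding only the source module leaves stale references behind. A self-contained illustration with synthetic module names (not vLLM's):

import sys
import types

producer = types.ModuleType("producer")

class Widget:                                   # stands in for the original GPUModelRunner
    pass

producer.Widget = Widget
sys.modules["producer"] = producer

from producer import Widget as ImportedWidget   # an earlier from-import, like gpu_worker's

class PatchedWidget(Widget):                    # stands in for the patched runner
    pass

producer.Widget = PatchedWidget                 # future imports of producer see the patch
print(ImportedWidget is PatchedWidget)          # False: the old binding must be rebound separately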
def _create_patched_runner(BaseRunner: type) -> type:
@@ -70,6 +121,8 @@ def _create_patched_runner(BaseRunner: type) -> type:
Returns a new class that inherits from the original and adds
shared memory + daemon functionality.
"""
import torch
import torch.multiprocessing as mp
from .weight_updater import weight_updater_process
class PatchedGPUModelRunner(BaseRunner):
@@ -86,25 +139,39 @@ def _create_patched_runner(BaseRunner: type) -> type:
vLLM immediately sees the new weights for inference.
"""
_shared_memory_setup_done = False
weight_updater_process = None
def load_model(self, *args, **kwargs) -> None:
"""Load model and set up shared memory + update daemon."""
print(f"[vLLM Patch] PatchedGPUModelRunner.load_model() called!")
# Call original load_model
super().load_model(*args, **kwargs)
print(f"[vLLM Patch] Model loaded, checking shared weights setup...")
# Check if shared memory updates are enabled
enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1))
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
print(f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}")
if not enable_shared and num_inference_nodes < 0:
print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)")
return
if self._shared_memory_setup_done:
print("[vLLM Patch] Shared memory already set up, skipping")
return
print("[vLLM Patch] Setting up shared memory weight updates...")
try:
self._setup_shared_memory()
self._spawn_weight_updater()
print("[vLLM Patch] ✓ Shared memory updates enabled")
PatchedGPUModelRunner._shared_memory_setup_done = True
print("[vLLM Patch] ✓ Shared memory updates enabled successfully!")
except Exception as e:
print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
import traceback
@@ -115,43 +182,73 @@ def _create_patched_runner(BaseRunner: type) -> type:
import json
from pathlib import Path
# Make entire model shareable
self.model.share_memory()
print("[vLLM Patch] _setup_shared_memory() starting...")
# Also share_memory_() on each parameter individually
# Get state dict
state_dict = self.model.state_dict()
for key, val in state_dict.items():
if val.is_cuda or val.device.type == 'cuda':
val.share_memory_()
print(f"[vLLM Patch] Model has {len(state_dict)} parameters")
print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")
# Make entire model shareable via share_memory_() on each tensor
shared_count = 0
for key, val in state_dict.items():
try:
if val.is_cuda:
val.share_memory_()
shared_count += 1
except Exception as e:
print(f"[vLLM Patch] Warning: Could not share {key}: {e}")
print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors")
# Also try calling share_memory() on the model itself
try:
self.model.share_memory()
print("[vLLM Patch] Called model.share_memory()")
except Exception as e:
print(f"[vLLM Patch] Note: model.share_memory() not available: {e}")
# Export parameter info to JSON for trainer
log_dir = os.environ.get("LOGDIR", ".")
log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge")
Path(log_dir).mkdir(parents=True, exist_ok=True)
json_path = Path(log_dir) / "vllm_bridge_config.json"
param_mappings = {}
param_names = []
for name, tensor in state_dict.items():
param_mappings[name] = {
"vllm_name": name,
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
}
param_names.append(name)
# Get model info
model_name = "unknown"
tp_degree = 1
try:
model_name = str(self.model_config.model)
tp_degree = self.parallel_config.tensor_parallel_size
except Exception as e:
print(f"[vLLM Patch] Warning: Could not get model config: {e}")
info = {
"model": str(self.model_config.model),
"tp_degree": self.parallel_config.tensor_parallel_size,
"model": model_name,
"tp_degree": tp_degree,
"dp_shard_degree": 1,
"param_mappings": param_mappings,
"param_names": sorted(state_dict.keys()),
"param_names": sorted(param_names),
"shared_weights_enabled": True,
"num_params": len(param_names),
}
try:
with open(json_path, "w") as f:
json.dump(info, f, indent=2)
print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
except Exception as e:
print(f"[vLLM Patch] Warning: Failed to export params: {e}")
print(f"[vLLM Patch] ERROR: Failed to export params: {e}")
import traceback
traceback.print_exc()
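For context, a trainer on the other side of the bridge could consume the exported file roughly like this (the reading code is an assumption; only the path default and the keys come from the info dict above):

import json
import os
from pathlib import Path

cfg_path = Path(os.environ.get("LOGDIR", "/tmp/atropos_bridge")) / "vllm_bridge_config.json"
with open(cfg_path) as f:
    bridge = json.load(f)

print(f"{bridge['model']} (tp_degree={bridge['tp_degree']}) exposes {bridge['num_params']} params")
for name in bridge["param_names"]:
    meta = bridge["param_mappings"][name]
    # shape and dtype tell the trainer how to stage each tensor before sending it
    staging_shape, staging_dtype = meta["shape"], meta["dtype"]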
def _spawn_weight_updater(self) -> None:
"""Spawn the daemon process for receiving weight updates."""
@@ -171,9 +268,19 @@ def _create_patched_runner(BaseRunner: type) -> type:
# Get parallel configuration
tp_rank = get_tensor_model_parallel_rank()
gpu_id = torch.cuda.device(self.device).idx if hasattr(self.device, 'idx') else 0
print(f"[vLLM Patch] Spawning updater: tp_rank={tp_rank}, gpu={gpu_id}")
# Get GPU ID
gpu_id = 0
try:
if hasattr(self, 'device'):
if hasattr(self.device, 'index'):
gpu_id = self.device.index or 0
elif isinstance(self.device, int):
gpu_id = self.device
except Exception:
gpu_id = tp_rank
print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}")
# Spawn daemon process
ctx = mp.get_context("spawn")
@@ -191,11 +298,26 @@ def _create_patched_runner(BaseRunner: type) -> type:
)
self.weight_updater_process.start()
print(f"[vLLM Patch] Weight updater daemon started (PID: {self.weight_updater_process.pid})")
print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})")
# Set proper class name
PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
return PatchedGPUModelRunner
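Reduced to its bones, the factory above follows the pattern below (generic names; a sketch of the technique, not the class itself):

class DemoBase:
    def load_model(self, *args, **kwargs) -> None:
        print("base load_model")

def make_patched(Base: type) -> type:
    class Patched(Base):
        def load_model(self, *args, **kwargs) -> None:
            super().load_model(*args, **kwargs)   # original behaviour first
            print("post-load setup goes here")    # then the extra hook

    Patched.__name__ = f"Patched{Base.__name__}"  # readable identity, as done above
    Patched.__qualname__ = Patched.__name__
    return Patched

make_patched(DemoBase)().load_model()             # prints both lines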
def get_patched_runner() -> type | None:
"""Get the patched runner class if patches have been applied."""
return _PATCHED_RUNNER_CLASS
def is_patched() -> bool:
"""Check if patches have been applied."""
return _PATCHES_APPLIED
# Placeholder class for type checking
class PatchedGPUModelRunner:
"""
Placeholder class for type checking.
@@ -204,5 +326,3 @@ class PatchedGPUModelRunner:
to properly inherit from vLLM's GPUModelRunner.
"""
pass
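A sketch of how a caller might combine the placeholder with the helpers defined above (the caller itself is hypothetical):

from example_trainer.vllm_patching import (
    PatchedGPUModelRunner,   # placeholder, useful for annotations
    get_patched_runner,
    is_patched,
)

def describe(runner: PatchedGPUModelRunner) -> str:
    # At runtime the object is an instance of the dynamically created class,
    # not of the placeholder; the annotation is for readers and type checkers.
    return type(runner).__name__

if is_patched():
    RealRunner = get_patched_runner()   # the class actually substituted into vLLM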