""" Patched GPU Model Runner - Enables shared memory weight updates. This patches vLLM's GPUModelRunner to: 1. Call share_memory_() on model weights after loading 2. Spawn a daemon process that receives NCCL weight updates from trainers The key insight is that share_memory_() makes tensors accessible from multiple processes. The daemon receives updates via NCCL and copies them directly into the shared tensors, which vLLM reads for inference. CRITICAL: This module must be imported and apply_patches() called BEFORE any vLLM imports. The patches MUST happen before vLLM caches module references. """ from __future__ import annotations import os import sys from typing import TYPE_CHECKING # Lazy imports to avoid circular dependencies if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner # Flag to track if patches have been applied _PATCHES_APPLIED = False _PATCHED_RUNNER_CLASS = None def apply_patches() -> bool: """ Apply patches to vLLM's GPUModelRunner in ALL locations. This must be called BEFORE importing vLLM's engine classes. Safe to call multiple times (idempotent). Returns True if patches were applied successfully. Usage: # CRITICAL: Import and call BEFORE any vLLM imports! import os os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" from example_trainer.vllm_patching import apply_patches apply_patches() # Now import vLLM from vllm import AsyncLLM # Uses patched runner """ global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS if _PATCHES_APPLIED: return True try: # Import the source module and get original class import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner # Create the patched class PatchedRunner = _create_patched_runner(OriginalRunner) _PATCHED_RUNNER_CLASS = PatchedRunner # ================================================================= # PATCH 1: Replace in source module # ================================================================= gpu_model_runner_module.GPUModelRunner = PatchedRunner print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner") # ================================================================= # PATCH 2: Replace in gpu_worker module (main usage location) # ================================================================= try: import vllm.v1.worker.gpu_worker as gpu_worker_module gpu_worker_module.GPUModelRunner = PatchedRunner print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner") except ImportError: pass # ================================================================= # PATCH 3: Update sys.modules entry for source module # ================================================================= # This ensures new imports get the patched version if 'vllm.v1.worker.gpu_model_runner' in sys.modules: sys.modules['vllm.v1.worker.gpu_model_runner'].GPUModelRunner = PatchedRunner # ================================================================= # PATCH 4: Patch GPUWorker if already imported # ================================================================= try: if 'vllm.v1.worker.gpu_worker' in sys.modules: worker_module = sys.modules['vllm.v1.worker.gpu_worker'] if hasattr(worker_module, 'GPUWorker'): # Update any class-level references worker_module.GPUModelRunner = PatchedRunner except Exception: pass _PATCHES_APPLIED = True print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates") return True except ImportError as e: print(f"[vLLM Patch] Warning: Could not apply patches: {e}") print("[vLLM Patch] This may be due to vLLM version incompatibility") print("[vLLM Patch] Shared memory updates will not be available") return False except Exception as e: print(f"[vLLM Patch] Error applying patches: {e}") import traceback traceback.print_exc() return False def _create_patched_runner(BaseRunner: type) -> type: """ Create a patched GPUModelRunner class. Returns a new class that inherits from the original and adds shared memory + daemon functionality. """ import torch import torch.multiprocessing as mp from .weight_updater import weight_updater_process class PatchedGPUModelRunner(BaseRunner): """ Patched GPUModelRunner that enables shared memory weight updates. After loading the model, this: 1. Calls share_memory_() on all parameters to make them accessible from other processes 2. Spawns a daemon process that joins NCCL groups with the trainer and receives weight updates The daemon copies updates directly into the shared tensors, so vLLM immediately sees the new weights for inference. """ _shared_memory_setup_done = False weight_updater_process = None def load_model(self, *args, **kwargs) -> None: """Load model and set up shared memory + update daemon.""" print(f"[vLLM Patch] PatchedGPUModelRunner.load_model() called!") # Call original load_model super().load_model(*args, **kwargs) print(f"[vLLM Patch] Model loaded, checking shared weights setup...") # Check if shared memory updates are enabled enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1" num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1")) print(f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}") if not enable_shared and num_inference_nodes < 0: print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)") return if self._shared_memory_setup_done: print("[vLLM Patch] Shared memory already set up, skipping") return print("[vLLM Patch] Setting up shared memory weight updates...", flush=True) try: self._setup_shared_memory() PatchedGPUModelRunner._shared_memory_setup_done = True print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True) except Exception as e: print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True) import traceback traceback.print_exc() return # Spawn weight updater daemon (optional - can be skipped for HTTP-only mode) skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1" if skip_daemon: print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True) return try: print("[vLLM Patch] Spawning weight updater daemon...", flush=True) self._spawn_weight_updater() print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True) except Exception as e: print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True) import traceback traceback.print_exc() print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True) def _setup_shared_memory(self) -> None: """Move model tensors to shared memory and export param info.""" import json from pathlib import Path print("[vLLM Patch] _setup_shared_memory() starting...") # Get state dict state_dict = self.model.state_dict() print(f"[vLLM Patch] Model has {len(state_dict)} parameters") # Make entire model shareable via share_memory_() on each tensor shared_count = 0 for key, val in state_dict.items(): try: if val.is_cuda: val.share_memory_() shared_count += 1 except Exception as e: print(f"[vLLM Patch] Warning: Could not share {key}: {e}") print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors") # Also try calling share_memory() on the model itself try: self.model.share_memory() print("[vLLM Patch] Called model.share_memory()") except Exception as e: print(f"[vLLM Patch] Note: model.share_memory() not available: {e}") # Export parameter info to JSON for trainer log_dir = os.environ.get("LOGDIR", ".") Path(log_dir).mkdir(parents=True, exist_ok=True) json_path = Path(log_dir) / "vllm_bridge_config.json" param_mappings = {} param_names = [] ipc_handles = {} for name, tensor in state_dict.items(): param_mappings[name] = { "vllm_name": name, "shape": list(tensor.shape), "dtype": str(tensor.dtype), "device": str(tensor.device), } param_names.append(name) # Export CUDA IPC handles for true single-copy mode if tensor.is_cuda: try: # Get the storage's IPC handle storage = tensor.untyped_storage() ipc_handle = storage._share_cuda_() ipc_handles[name] = { "handle": ipc_handle[0].hex() if isinstance(ipc_handle[0], bytes) else str(ipc_handle[0]), "storage_size": ipc_handle[1], "storage_offset": tensor.storage_offset(), "shape": list(tensor.shape), "stride": list(tensor.stride()), "dtype": str(tensor.dtype), "device_index": tensor.device.index, } except Exception as e: print(f"[vLLM Patch] Could not get IPC handle for {name}: {e}", flush=True) print(f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", flush=True) # Get model info model_name = "unknown" tp_degree = 1 try: model_name = str(self.model_config.model) tp_degree = self.parallel_config.tensor_parallel_size except Exception as e: print(f"[vLLM Patch] Warning: Could not get model config: {e}") info = { "model": model_name, "tp_degree": tp_degree, "dp_shard_degree": 1, "param_mappings": param_mappings, "param_names": sorted(param_names), "ipc_handles": ipc_handles, "shared_weights_enabled": True, "num_params": len(param_names), "single_copy_enabled": len(ipc_handles) > 0, } try: with open(json_path, "w") as f: json.dump(info, f, indent=2) print(f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}") except Exception as e: print(f"[vLLM Patch] ERROR: Failed to export params: {e}") import traceback traceback.print_exc() def _spawn_weight_updater(self) -> None: """Start the weight updater as a background thread. Note: We use threading instead of multiprocessing because vLLM's worker processes are daemons, and daemons cannot spawn child processes. """ import threading print("[vLLM Patch] _spawn_weight_updater() called", flush=True) try: from vllm.distributed import get_tensor_model_parallel_rank print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True) except ImportError as e: print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True) get_tensor_model_parallel_rank = lambda: 0 # Get model configuration state_dict = self.model.state_dict() print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True) # Get attention head counts hf_config = self.model_config.hf_text_config num_heads = getattr(hf_config, "num_attention_heads", 0) num_kv_heads = self.model_config.get_total_num_kv_heads() print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True) # Get parallel configuration tp_rank = get_tensor_model_parallel_rank() print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True) # Get GPU ID gpu_id = 0 try: if hasattr(self, 'device'): if hasattr(self.device, 'index'): gpu_id = self.device.index or 0 elif isinstance(self.device, int): gpu_id = self.device except Exception: gpu_id = tp_rank print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True) # Start as a daemon thread (threads CAN be started from daemon processes) self.weight_updater_thread = threading.Thread( target=weight_updater_process, args=( state_dict, num_heads, num_kv_heads, tp_rank, self.parallel_config.tensor_parallel_size, gpu_id, ), daemon=True, name=f"WeightUpdater_TP{tp_rank}", ) print("[vLLM Patch] Starting thread...", flush=True) self.weight_updater_thread.start() print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True) # Set proper class name PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner" PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner" return PatchedGPUModelRunner def get_patched_runner() -> type | None: """Get the patched runner class if patches have been applied.""" return _PATCHED_RUNNER_CLASS def is_patched() -> bool: """Check if patches have been applied.""" return _PATCHES_APPLIED # Placeholder class for type checking class PatchedGPUModelRunner: """ Placeholder class for type checking. The actual patched class is created dynamically by _create_patched_runner() to properly inherit from vLLM's GPUModelRunner. """ pass