atropos/example_trainer/vllm_patching/patched_gpu_runner.py
2026-03-02 11:18:51 -05:00

208 lines
7.4 KiB
Python

"""
Patched GPU Model Runner - Enables shared memory weight updates.
This patches vLLM's GPUModelRunner to:
1. Call share_memory_() on model weights after loading
2. Spawn a daemon process that receives NCCL weight updates from trainers
The key insight is that share_memory_() makes tensors accessible from
multiple processes. The daemon receives updates via NCCL and copies them
directly into the shared tensors, which vLLM reads for inference.
"""
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import torch
import torch.multiprocessing as mp
# Lazy imports to avoid circular dependencies
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# Flag to track if patches have been applied. apply_patches() sets it to True
# only after the runner class has been swapped in, and returns early on any
# later call while it is True.
_PATCHES_APPLIED = False
def apply_patches() -> None:
    """
    Swap vLLM's GPUModelRunner for the shared-memory-aware subclass.

    The replacement class is installed as
    ``vllm.v1.worker.gpu_worker.GPUModelRunner``. This must be called
    BEFORE importing vLLM's engine classes.

    Calling this more than once is harmless: once a patch has succeeded,
    later calls return immediately. If the required imports fail, a
    warning is printed, nothing is patched, and the applied-flag is left
    unset so a later call will try again.

    Usage:
        from example_trainer.vllm_patching import apply_patches
        apply_patches()
        from vllm import AsyncLLM  # Now uses patched runner
    """
    global _PATCHES_APPLIED

    if not _PATCHES_APPLIED:
        try:
            import vllm.v1.worker.gpu_worker as worker_module
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as stock_runner,
            )

            # Build the subclass first, then rebind the name the worker
            # module looks up when it instantiates its runner.
            worker_module.GPUModelRunner = _create_patched_runner(stock_runner)
        except ImportError as exc:
            print(f"[vLLM Patch] Warning: Could not apply patches: {exc}")
            print("[vLLM Patch] Shared memory updates will not be available")
        else:
            _PATCHES_APPLIED = True
            print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")
def _create_patched_runner(BaseRunner: type) -> type:
    """
    Create a patched GPUModelRunner class.

    Returns a new class that inherits from ``BaseRunner`` and adds
    shared memory + weight-updater daemon functionality on top of the
    base class's ``load_model``.

    Args:
        BaseRunner: The runner class to subclass. apply_patches() passes
            vLLM's original GPUModelRunner.

    Returns:
        The dynamically created ``PatchedGPUModelRunner`` subclass.

    Raises:
        ImportError: If the sibling ``weight_updater`` module cannot be
            imported.
    """
    # Imported lazily so the updater module is only loaded when a patched
    # class is actually being built.
    from .weight_updater import weight_updater_process

    class PatchedGPUModelRunner(BaseRunner):
        """
        Patched GPUModelRunner that enables shared memory weight updates.

        After loading the model, this:
        1. Calls share_memory_() on all parameters to make them accessible
           from other processes
        2. Spawns a daemon process that joins NCCL groups with the trainer
           and receives weight updates

        The daemon copies updates directly into the shared tensors, so
        vLLM immediately sees the new weights for inference.
        """

        def load_model(self, *args, **kwargs) -> None:
            """
            Load the model, then set up shared memory + the update daemon.

            All arguments are forwarded unchanged to the base class's
            ``load_model``. The extra setup runs when
            ``VLLM_ENABLE_SHARED_WEIGHTS`` is ``"1"`` or when
            ``NUM_INFERENCE_NODES`` is set to a non-negative integer.
            Failures during that setup are printed (with traceback) and
            swallowed so that model loading itself still succeeds.
            """
            # Call original load_model
            super().load_model(*args, **kwargs)

            # Check if shared memory updates are enabled
            enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
            num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1))

            if not enable_shared and num_inference_nodes < 0:
                print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)")
                return

            print("[vLLM Patch] Setting up shared memory weight updates...")
            try:
                self._setup_shared_memory()
                self._spawn_weight_updater()
                print("[vLLM Patch] ✓ Shared memory updates enabled")
            except Exception as e:
                # Best effort: inference must keep working even if the
                # weight-update bridge cannot be established.
                print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
                import traceback
                traceback.print_exc()

        def _setup_shared_memory(self) -> None:
            """
            Mark model tensors as shared and export parameter info.

            Writes ``vllm_bridge_config.json`` into ``$LOGDIR`` (current
            directory when unset) containing the model name, the
            tensor-parallel degree and the name/shape/dtype of every entry
            in the model's state dict. A failure to write the file is
            reported but not raised.
            """
            import json
            from pathlib import Path

            # Make entire model shareable
            self.model.share_memory()

            # Also share_memory_() on each CUDA tensor individually
            state_dict = self.model.state_dict()
            for tensor in state_dict.values():
                if tensor.is_cuda:
                    tensor.share_memory_()
            print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")

            # Export parameter info to JSON for trainer
            log_dir = os.environ.get("LOGDIR", ".")
            json_path = Path(log_dir) / "vllm_bridge_config.json"

            param_mappings = {
                name: {
                    "vllm_name": name,
                    "shape": list(tensor.shape),
                    "dtype": str(tensor.dtype),
                }
                for name, tensor in state_dict.items()
            }
            info = {
                "model": str(self.model_config.model),
                "tp_degree": self.parallel_config.tensor_parallel_size,
                "dp_shard_degree": 1,
                "param_mappings": param_mappings,
                "param_names": sorted(state_dict.keys()),
            }

            try:
                with open(json_path, "w") as f:
                    json.dump(info, f, indent=2)
                print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
            except Exception as e:
                print(f"[vLLM Patch] Warning: Failed to export params: {e}")

        def _spawn_weight_updater(self) -> None:
            """
            Spawn the daemon process for receiving weight updates.

            The child is started with the ``spawn`` context and runs
            ``weight_updater_process`` with, positionally: the model state
            dict, the attention head count, the total KV head count, this
            worker's tensor-parallel rank, the tensor-parallel size and the
            GPU index. The process handle is kept on
            ``self.weight_updater_process``.
            """
            try:
                from vllm.distributed import get_tensor_model_parallel_rank
            except ImportError:
                # Fallback for older vLLM versions
                def get_tensor_model_parallel_rank() -> int:
                    return 0

            # Get model configuration
            state_dict = self.model.state_dict()

            # Get attention head counts
            hf_config = self.model_config.hf_text_config
            num_heads = getattr(hf_config, "num_attention_heads", 0)
            num_kv_heads = self.model_config.get_total_num_kv_heads()

            # Get parallel configuration
            tp_rank = get_tensor_model_parallel_rank()

            # Resolve the GPU ordinal for the daemon. torch.device exposes it
            # as `.index` (there is no `.idx`), so probing for `idx` on a
            # torch.device always failed and every worker reported GPU 0.
            device = self.device
            if isinstance(device, (str, int)):
                device = torch.device(device)
            if isinstance(device, torch.device):
                device_index = device.index
            else:
                # Still honour objects that carry an `idx` attribute, as the
                # previous hasattr-based check did.
                device_index = getattr(device, "idx", None)
            # Keep the historical fallback of 0 when no ordinal is available
            # (e.g. a bare "cuda" device with no index).
            gpu_id = device_index if isinstance(device_index, int) else 0

            print(f"[vLLM Patch] Spawning updater: tp_rank={tp_rank}, gpu={gpu_id}")

            # Spawn daemon process
            ctx = mp.get_context("spawn")
            self.weight_updater_process = ctx.Process(
                target=weight_updater_process,
                args=(
                    state_dict,
                    num_heads,
                    num_kv_heads,
                    tp_rank,
                    self.parallel_config.tensor_parallel_size,
                    gpu_id,
                ),
                daemon=True,
            )
            self.weight_updater_process.start()
            print(f"[vLLM Patch] Weight updater daemon started (PID: {self.weight_updater_process.pid})")

    return PatchedGPUModelRunner
class PatchedGPUModelRunner:
    """
    Static stand-in that exists only so the name resolves for type checkers.

    It defines no behavior of its own. The functional class of the same
    name is built at runtime inside _create_patched_runner(), which lets it
    subclass vLLM's GPUModelRunner.
    """