Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-22 16:48:57 +00:00)

Commit fad8e77be2: "patching problem"
Parent: 80d2608c4e
4 changed files with 234 additions and 76 deletions
@@ -8,16 +8,17 @@ This patches vLLM's GPUModelRunner to:
 
 The key insight is that share_memory_() makes tensors accessible from
 multiple processes. The daemon receives updates via NCCL and copies them
 directly into the shared tensors, which vLLM reads for inference.
 
 CRITICAL: This module must be imported and apply_patches() called BEFORE
 any vLLM imports. The patches MUST happen before vLLM caches module references.
 """
 
 from __future__ import annotations
 
 import os
+import sys
 from typing import TYPE_CHECKING
 
 import torch
 import torch.multiprocessing as mp
 
 # Lazy imports to avoid circular dependencies
 if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
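The mechanism the module docstring leans on can be demonstrated in isolation. A minimal sketch, separate from the commit, using a CPU tensor for simplicity (CUDA tensors are shared across processes via CUDA IPC rather than `share_memory_()`):

```python
import torch
import torch.multiprocessing as mp


def _writer(shared: torch.Tensor) -> None:
    # Mutates the tensor in place; nothing is copied back to the parent.
    shared.fill_(42.0)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    t = torch.zeros(4)
    t.share_memory_()  # move the tensor's storage into shared memory
    p = ctx.Process(target=_writer, args=(t,))
    p.start()
    p.join()
    print(t)  # tensor([42., 42., 42., 42.]): the parent sees the child's write
```

This is the same one-way contract the daemon relies on: whoever holds a handle to the shared storage observes writes immediately, with no per-update pickling round-trip.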
@@ -25,42 +26,92 @@ if TYPE_CHECKING:
 
 # Flag to track if patches have been applied
 _PATCHES_APPLIED = False
+_PATCHED_RUNNER_CLASS = None
 
 
-def apply_patches() -> None:
+def apply_patches() -> bool:
     """
-    Apply patches to vLLM's GPUModelRunner.
+    Apply patches to vLLM's GPUModelRunner in ALL locations.
 
     This must be called BEFORE importing vLLM's engine classes.
     Safe to call multiple times (idempotent).
 
+    Returns True if patches were applied successfully.
+
     Usage:
         # CRITICAL: Import and call BEFORE any vLLM imports!
         import os
         os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
 
         from example_trainer.vllm_patching import apply_patches
         apply_patches()
 
-        from vllm import AsyncLLM  # Now uses patched runner
+        # Now import vLLM
+        from vllm import AsyncLLM  # Uses patched runner
     """
-    global _PATCHES_APPLIED
+    global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS
 
     if _PATCHES_APPLIED:
-        return
+        return True
 
     try:
-        import vllm.v1.worker.gpu_worker
+        # Import the source module and get original class
+        import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module
         from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner
 
-        # Create patched class
+        # Create the patched class
         PatchedRunner = _create_patched_runner(OriginalRunner)
+        _PATCHED_RUNNER_CLASS = PatchedRunner
 
-        # Replace in vllm module
-        vllm.v1.worker.gpu_worker.GPUModelRunner = PatchedRunner
+        # =================================================================
+        # PATCH 1: Replace in source module
+        # =================================================================
+        gpu_model_runner_module.GPUModelRunner = PatchedRunner
+        print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner")
+
+        # =================================================================
+        # PATCH 2: Replace in gpu_worker module (main usage location)
+        # =================================================================
+        try:
+            import vllm.v1.worker.gpu_worker as gpu_worker_module
+            gpu_worker_module.GPUModelRunner = PatchedRunner
+            print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner")
+        except ImportError:
+            pass
+
+        # =================================================================
+        # PATCH 3: Update sys.modules entry for source module
+        # =================================================================
+        # This ensures new imports get the patched version
+        if 'vllm.v1.worker.gpu_model_runner' in sys.modules:
+            sys.modules['vllm.v1.worker.gpu_model_runner'].GPUModelRunner = PatchedRunner
+
+        # =================================================================
+        # PATCH 4: Patch GPUWorker if already imported
+        # =================================================================
+        try:
+            if 'vllm.v1.worker.gpu_worker' in sys.modules:
+                worker_module = sys.modules['vllm.v1.worker.gpu_worker']
+                if hasattr(worker_module, 'GPUWorker'):
+                    # Update any class-level references
+                    worker_module.GPUModelRunner = PatchedRunner
+        except Exception:
+            pass
 
         _PATCHES_APPLIED = True
         print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")
+        return True
 
     except ImportError as e:
         print(f"[vLLM Patch] Warning: Could not apply patches: {e}")
+        print("[vLLM Patch] This may be due to vLLM version incompatibility")
+        print("[vLLM Patch] Shared memory updates will not be available")
+        return False
+    except Exception as e:
+        print(f"[vLLM Patch] Error applying patches: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
 
 
 def _create_patched_runner(BaseRunner: type) -> type:
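The four patch sites are needed because `from module import Name` copies a reference: rebinding the module attribute afterwards does not retarget names a consumer has already imported. A toy reproduction of that pitfall (hypothetical module name `toymod`, not part of the commit):

```python
import sys
import types

# Fabricate an importable module with one class in it.
m = types.ModuleType("toymod")


class Original:
    pass


m.Original = Original
sys.modules["toymod"] = m

from toymod import Original as CachedRef  # an "early" consumer import


class Patched(Original):
    pass


sys.modules["toymod"].Original = Patched  # rebind the module attribute

assert CachedRef is Original  # the early consumer still holds the old class
from toymod import Original as FreshRef

assert FreshRef is Patched  # only lookups made after the patch see it
```

Hence the ordering rule in the docstring: the patch must land before vLLM's own modules execute their `from ... import GPUModelRunner` lines.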
@@ -70,6 +121,8 @@ def _create_patched_runner(BaseRunner: type) -> type:
     Returns a new class that inherits from the original and adds
     shared memory + daemon functionality.
     """
+    import torch
+    import torch.multiprocessing as mp
     from .weight_updater import weight_updater_process
 
     class PatchedGPUModelRunner(BaseRunner):
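`_create_patched_runner` builds the subclass at runtime so the base class stays a late-bound parameter and vLLM is never imported at module load time. The same factory pattern in miniature (illustrative names only, not the commit's code):

```python
def make_logged(Base: type) -> type:
    """Return a runtime-built subclass of an arbitrary base class."""

    class Logged(Base):
        def __init__(self, *args, **kwargs):
            print(f"constructing a {Base.__name__}")
            super().__init__(*args, **kwargs)

    Logged.__name__ = f"Logged{Base.__name__}"
    return Logged


LoggedDict = make_logged(dict)
d = LoggedDict(a=1)  # prints "constructing a dict"; d == {"a": 1}
```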
@@ -86,25 +139,39 @@ def _create_patched_runner(BaseRunner: type) -> type:
         vLLM immediately sees the new weights for inference.
         """
 
+        _shared_memory_setup_done = False
+        weight_updater_process = None
+
         def load_model(self, *args, **kwargs) -> None:
             """Load model and set up shared memory + update daemon."""
+            print("[vLLM Patch] PatchedGPUModelRunner.load_model() called!")
+
             # Call original load_model
             super().load_model(*args, **kwargs)
 
+            print("[vLLM Patch] Model loaded, checking shared weights setup...")
+
             # Check if shared memory updates are enabled
             enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
-            num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1))
+            num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
+
+            print(f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}")
 
             if not enable_shared and num_inference_nodes < 0:
+                print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)")
                 return
 
+            if self._shared_memory_setup_done:
+                print("[vLLM Patch] Shared memory already set up, skipping")
+                return
+
+            print("[vLLM Patch] Setting up shared memory weight updates...")
+
             try:
                 self._setup_shared_memory()
                 self._spawn_weight_updater()
-                print("[vLLM Patch] ✓ Shared memory updates enabled")
+                PatchedGPUModelRunner._shared_memory_setup_done = True
+                print("[vLLM Patch] ✓ Shared memory updates enabled successfully!")
             except Exception as e:
                 print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
                 import traceback
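Given the gate above, setup runs when either `VLLM_ENABLE_SHARED_WEIGHTS` is `"1"` or `NUM_INFERENCE_NODES` is non-negative. A launcher-side sketch of the assumed workflow (only the variable names and module path come from this diff):

```python
import os

# Must be set before the engine spawns workers; each worker reads them once.
os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"  # enables the setup path
os.environ["NUM_INFERENCE_NODES"] = "1"         # any value >= 0 also enables it
os.environ["LOGDIR"] = "/tmp/atropos_bridge"    # where the bridge JSON is written

from example_trainer.vllm_patching import apply_patches

apply_patches()  # must precede any vLLM import

from vllm import AsyncLLM  # noqa: E402  (workers now use the patched runner)
```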
@@ -115,43 +182,73 @@ def _create_patched_runner(BaseRunner: type) -> type:
             import json
             from pathlib import Path
 
-            # Make entire model shareable
-            self.model.share_memory()
+            print("[vLLM Patch] _setup_shared_memory() starting...")
 
-            # Also share_memory_() on each parameter individually
+            # Get state dict
             state_dict = self.model.state_dict()
-            for key, val in state_dict.items():
-                if val.is_cuda or val.device.type == 'cuda':
-                    val.share_memory_()
+            print(f"[vLLM Patch] Model has {len(state_dict)} parameters")
 
-            print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")
+            # Make entire model shareable via share_memory_() on each tensor
+            shared_count = 0
+            for key, val in state_dict.items():
+                try:
+                    if val.is_cuda:
+                        val.share_memory_()
+                        shared_count += 1
+                except Exception as e:
+                    print(f"[vLLM Patch] Warning: Could not share {key}: {e}")
+
+            print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors")
+
+            # Also try calling share_memory() on the model itself
+            try:
+                self.model.share_memory()
+                print("[vLLM Patch] Called model.share_memory()")
+            except Exception as e:
+                print(f"[vLLM Patch] Note: model.share_memory() not available: {e}")
 
             # Export parameter info to JSON for trainer
-            log_dir = os.environ.get("LOGDIR", ".")
+            log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge")
             Path(log_dir).mkdir(parents=True, exist_ok=True)
             json_path = Path(log_dir) / "vllm_bridge_config.json"
 
             param_mappings = {}
+            param_names = []
             for name, tensor in state_dict.items():
                 param_mappings[name] = {
                     "vllm_name": name,
                     "shape": list(tensor.shape),
                     "dtype": str(tensor.dtype),
                 }
+                param_names.append(name)
+
+            # Get model info
+            model_name = "unknown"
+            tp_degree = 1
+            try:
+                model_name = str(self.model_config.model)
+                tp_degree = self.parallel_config.tensor_parallel_size
+            except Exception as e:
+                print(f"[vLLM Patch] Warning: Could not get model config: {e}")
 
             info = {
-                "model": str(self.model_config.model),
-                "tp_degree": self.parallel_config.tensor_parallel_size,
+                "model": model_name,
+                "tp_degree": tp_degree,
                 "dp_shard_degree": 1,
                 "param_mappings": param_mappings,
-                "param_names": sorted(state_dict.keys()),
+                "param_names": sorted(param_names),
+                "shared_weights_enabled": True,
+                "num_params": len(param_names),
             }
 
             try:
                 with open(json_path, "w") as f:
                     json.dump(info, f, indent=2)
-                print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
+                print(f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}")
             except Exception as e:
-                print(f"[vLLM Patch] Warning: Failed to export params: {e}")
+                print(f"[vLLM Patch] ERROR: Failed to export params: {e}")
+                import traceback
+                traceback.print_exc()
 
         def _spawn_weight_updater(self) -> None:
             """Spawn the daemon process for receiving weight updates."""
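The JSON written here is the handshake the trainer reads back. A consumer sketch matching the layout above (the reader side is assumed, not part of this commit):

```python
import json
import os
from pathlib import Path

# Default mirrors the writer side in _setup_shared_memory().
log_dir = Path(os.environ.get("LOGDIR", "/tmp/atropos_bridge"))
info = json.loads((log_dir / "vllm_bridge_config.json").read_text())

print(f"model={info['model']} tp={info['tp_degree']} params={info['num_params']}")

# Per-parameter shapes and dtypes let the trainer validate tensors
# before pushing them over NCCL.
shapes = {
    name: tuple(info["param_mappings"][name]["shape"])
    for name in info["param_names"]
}
```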
@@ -171,9 +268,19 @@ def _create_patched_runner(BaseRunner: type) -> type:
 
             # Get parallel configuration
             tp_rank = get_tensor_model_parallel_rank()
-            gpu_id = torch.cuda.device(self.device).idx if hasattr(self.device, 'idx') else 0
 
-            print(f"[vLLM Patch] Spawning updater: tp_rank={tp_rank}, gpu={gpu_id}")
+            # Get GPU ID
+            gpu_id = 0
+            try:
+                if hasattr(self, 'device'):
+                    if hasattr(self.device, 'index'):
+                        gpu_id = self.device.index or 0
+                    elif isinstance(self.device, int):
+                        gpu_id = self.device
+            except Exception:
+                gpu_id = tp_rank
+
+            print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}")
 
             # Spawn daemon process
             ctx = mp.get_context("spawn")
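The `self.device.index or 0` fallback exists because a bare device spec carries no index. The point in two lines of plain PyTorch (not part of the commit):

```python
import torch

assert torch.device("cuda").index is None  # bare spec: no index, hence the `or 0`
assert torch.device("cuda:1").index == 1   # explicit spec: the index survives
```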
@@ -191,11 +298,26 @@ def _create_patched_runner(BaseRunner: type) -> type:
             )
             self.weight_updater_process.start()
 
-            print(f"[vLLM Patch] Weight updater daemon started (PID: {self.weight_updater_process.pid})")
+            print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})")
+
+    # Set proper class name
+    PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
+    PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
 
     return PatchedGPUModelRunner
 
 
+def get_patched_runner() -> type | None:
+    """Get the patched runner class if patches have been applied."""
+    return _PATCHED_RUNNER_CLASS
+
+
+def is_patched() -> bool:
+    """Check if patches have been applied."""
+    return _PATCHES_APPLIED
+
+
 # Placeholder class for type checking
 class PatchedGPUModelRunner:
     """
     Placeholder class for type checking.
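With the two helpers in place, callers can confirm the patch took effect before constructing an engine. A usage sketch (module path taken from the Usage docstring earlier in the diff):

```python
from example_trainer.vllm_patching import apply_patches, get_patched_runner, is_patched

if apply_patches() and is_patched():
    Runner = get_patched_runner()
    print(f"Active runner class: {Runner.__name__}")  # PatchedGPUModelRunner
else:
    print("Patching failed; shared-memory weight updates are unavailable")
```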
@@ -204,5 +326,3 @@ class PatchedGPUModelRunner:
     to properly inherit from vLLM's GPUModelRunner.
     """
     pass
-
-