"""
|
|
Patched GPU Model Runner - Enables shared memory weight updates.
|
|
|
|
This patches vLLM's GPUModelRunner to:
|
|
1. Call share_memory_() on model weights after loading
|
|
2. Spawn a daemon process that receives NCCL weight updates from trainers
|
|
|
|
The key insight is that share_memory_() makes tensors accessible from
|
|
multiple processes. The daemon receives updates via NCCL and copies them
|
|
directly into the shared tensors, which vLLM reads for inference.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
|
|
# Lazy imports to avoid circular dependencies
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
|
|
|
|
|
# Flag to track if patches have been applied
|
|
_PATCHES_APPLIED = False
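

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch itself; the function names are
# hypothetical): the cross-process pattern described in the module docstring,
# reduced to a CPU tensor. A tensor backed by shared memory via
# share_memory_() can be handed to a spawned child process, and in-place
# writes (copy_) made by the child are immediately visible to the parent.
# The weight-updater daemon below relies on the same in-place-update
# semantics for the model's parameters.
# ---------------------------------------------------------------------------
def _demo_child_update(shared: torch.Tensor) -> None:
    """Child process: overwrite the shared tensor in place."""
    shared.copy_(torch.ones_like(shared))


def _demo_shared_memory_roundtrip() -> None:
    """Parent process: observe the child's in-place update."""
    weight = torch.zeros(4)
    weight.share_memory_()  # back the tensor with shared memory

    ctx = mp.get_context("spawn")
    child = ctx.Process(target=_demo_child_update, args=(weight,))
    child.start()
    child.join()

    assert bool(weight.eq(1).all())  # the parent sees the child's values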


def apply_patches() -> None:
    """
    Apply patches to vLLM's GPUModelRunner.

    This must be called BEFORE importing vLLM's engine classes.
    Safe to call multiple times (idempotent).

    Usage:
        from example_trainer.vllm_patching import apply_patches
        apply_patches()

        from vllm import AsyncLLM  # Now uses patched runner
    """
    global _PATCHES_APPLIED

    if _PATCHES_APPLIED:
        return

    try:
        import vllm.v1.worker.gpu_worker
        from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner

        # Create patched class
        PatchedRunner = _create_patched_runner(OriginalRunner)

        # Replace in vllm module
        vllm.v1.worker.gpu_worker.GPUModelRunner = PatchedRunner

        _PATCHES_APPLIED = True
        print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")

    except ImportError as e:
        print(f"[vLLM Patch] Warning: Could not apply patches: {e}")
        print("[vLLM Patch] Shared memory updates will not be available")


def _create_patched_runner(BaseRunner: type) -> type:
    """
    Create a patched GPUModelRunner class.

    Returns a new class that inherits from the original and adds
    shared memory + daemon functionality.
    """
    from .weight_updater import weight_updater_process

    class PatchedGPUModelRunner(BaseRunner):
        """
        Patched GPUModelRunner that enables shared memory weight updates.

        After loading the model, this:

        1. Calls share_memory_() on all parameters to make them accessible
           from other processes
        2. Spawns a daemon process that joins NCCL groups with the trainer
           and receives weight updates

        The daemon copies updates directly into the shared tensors, so
        vLLM immediately sees the new weights for inference.
        """

        def load_model(self, *args, **kwargs) -> None:
            """Load model and set up shared memory + update daemon."""
            # Call original load_model
            super().load_model(*args, **kwargs)

            # Check if shared memory updates are enabled
            enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
            num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))

            if not enable_shared and num_inference_nodes < 0:
                print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)")
                return

            print("[vLLM Patch] Setting up shared memory weight updates...")

            try:
                self._setup_shared_memory()
                self._spawn_weight_updater()
                print("[vLLM Patch] ✓ Shared memory updates enabled")
            except Exception as e:
                print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
                import traceback

                traceback.print_exc()
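
        # Usage note (the shell invocation is hypothetical): the
        # shared-weights path is opt-in. Either export
        # VLLM_ENABLE_SHARED_WEIGHTS=1 (plus LOGDIR, used by
        # _setup_shared_memory() for the JSON export) or set
        # NUM_INFERENCE_NODES to a non-negative value before starting vLLM:
        #
        #   VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/run python my_server.py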

        def _setup_shared_memory(self) -> None:
            """Move model tensors to shared memory and export param info."""
            import json
            from pathlib import Path

            # Make entire model shareable
            self.model.share_memory()

            # Also share_memory_() on each parameter individually
            state_dict = self.model.state_dict()
            for val in state_dict.values():
                if val.is_cuda:
                    val.share_memory_()

            print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")

            # Export parameter info to JSON for trainer
            log_dir = os.environ.get("LOGDIR", ".")
            json_path = Path(log_dir) / "vllm_bridge_config.json"

            param_mappings = {}
            for name, tensor in state_dict.items():
                param_mappings[name] = {
                    "vllm_name": name,
                    "shape": list(tensor.shape),
                    "dtype": str(tensor.dtype),
                }

            info = {
                "model": str(self.model_config.model),
                "tp_degree": self.parallel_config.tensor_parallel_size,
                "dp_shard_degree": 1,
                "param_mappings": param_mappings,
                "param_names": sorted(state_dict.keys()),
            }

            try:
                with open(json_path, "w") as f:
                    json.dump(info, f, indent=2)
                print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
            except Exception as e:
                print(f"[vLLM Patch] Warning: Failed to export params: {e}")

        def _spawn_weight_updater(self) -> None:
            """Spawn the daemon process for receiving weight updates."""
            try:
                from vllm.distributed import get_tensor_model_parallel_rank
            except ImportError:
                # Fallback for older vLLM versions
                def get_tensor_model_parallel_rank() -> int:
                    return 0

            # Get model configuration
            state_dict = self.model.state_dict()

            # Get attention head counts
            hf_config = self.model_config.hf_text_config
            num_heads = getattr(hf_config, "num_attention_heads", 0)
            num_kv_heads = self.model_config.get_total_num_kv_heads()

            # Get parallel configuration
            tp_rank = get_tensor_model_parallel_rank()
            # torch.device stores the ordinal as .index (None for a bare "cuda")
            gpu_id = self.device.index if self.device.index is not None else 0

            print(f"[vLLM Patch] Spawning updater: tp_rank={tp_rank}, gpu={gpu_id}")

            # Spawn daemon process
            ctx = mp.get_context("spawn")
            self.weight_updater_process = ctx.Process(
                target=weight_updater_process,
                args=(
                    state_dict,
                    num_heads,
                    num_kv_heads,
                    tp_rank,
                    self.parallel_config.tensor_parallel_size,
                    gpu_id,
                ),
                daemon=True,
            )
            self.weight_updater_process.start()

            print(f"[vLLM Patch] Weight updater daemon started (PID: {self.weight_updater_process.pid})")

    return PatchedGPUModelRunner
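

# ---------------------------------------------------------------------------
# Hedged sketch of the consumer side (the helper name is hypothetical and
# nothing in this module calls it): how a trainer might read the
# vllm_bridge_config.json exported by _setup_shared_memory() to recover each
# parameter's shape and dtype before streaming NCCL weight updates.
# ---------------------------------------------------------------------------
def _demo_read_bridge_config(log_dir: str = ".") -> dict:
    """Return {param_name: (shape, dtype_string)} from the exported JSON."""
    import json
    from pathlib import Path

    with open(Path(log_dir) / "vllm_bridge_config.json") as f:
        info = json.load(f)

    return {
        name: (tuple(mapping["shape"]), mapping["dtype"])
        for name, mapping in info["param_mappings"].items()
    }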


class PatchedGPUModelRunner:
    """
    Placeholder class for type checking.

    The actual patched class is created dynamically by _create_patched_runner()
    to properly inherit from vLLM's GPUModelRunner.
    """