clearing more bloat

Jai Suphavadeeprasit 2026-01-17 13:49:43 -05:00
parent ab8d2f2dac
commit 036b87e921
4 changed files with 27 additions and 682 deletions

@@ -1,13 +1,15 @@
 """
-Patched GPU Model Runner - Enables shared memory weight updates.
+Patched GPU Model Runner - Enables CUDA IPC for single-copy training.
 
 This patches vLLM's GPUModelRunner to:
 1. Call share_memory_() on model weights after loading
-2. Spawn a daemon process that receives NCCL weight updates from trainers
+2. Export CUDA IPC handles to vllm_bridge_config.json
 
-The key insight is that share_memory_() makes tensors accessible from
-multiple processes. The daemon receives updates via NCCL and copies them
-directly into the shared tensors, which vLLM reads for inference.
+The key insight is that CUDA IPC handles allow the trainer process to
+attach to the EXACT SAME GPU memory that vLLM uses. This means:
+- ONE copy of model weights in GPU memory
+- Trainer's optimizer.step() updates vLLM's weights directly
+- No synchronization needed - vLLM immediately sees new weights
 
 CRITICAL: This module must be imported and apply_patches() called BEFORE
 any vLLM imports. The patches MUST happen before vLLM caches module references.
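
For context on the mechanism the new docstring describes: PyTorch can serialize a CUDA tensor into an IPC handle that another process on the same machine reopens as a view of the same GPU allocation. A minimal sketch, assuming both processes share one GPU and using torch.multiprocessing's reduction machinery; the variable names and the inline "two sides" are illustrative, not this repo's code:

```python
# Sketch of the single-copy mechanism. Running both halves in one process
# works because PyTorch caches the storage; in real use the blob crosses
# a process boundary (here, via a bridge file).
import pickle

import torch
from torch.multiprocessing.reductions import reduce_tensor

# --- vLLM side: turn a GPU-resident weight into a serializable IPC handle ---
weight = torch.zeros(1024, 1024, device="cuda")   # stand-in for a model param
rebuild_fn, rebuild_args = reduce_tensor(weight)  # wraps cudaIpcGetMemHandle
blob = pickle.dumps((rebuild_fn, rebuild_args))   # what a bridge file carries

# --- trainer side: reopen the handle as a view of the SAME allocation ---
rebuild_fn, rebuild_args = pickle.loads(blob)
shared = rebuild_fn(*rebuild_args)  # no copy made (cudaIpcOpenMemHandle)
shared.add_(1.0)                    # vLLM would see this write immediately
```

Because the rebuilt tensor is a view rather than a copy, writes from the trainer are visible to vLLM without any transfer or synchronization step, which is exactly the single-copy property the docstring claims.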
@@ -119,28 +121,24 @@ def _create_patched_runner(BaseRunner: type) -> type:
     Create a patched GPUModelRunner class.
 
     Returns a new class that inherits from the original and adds
-    shared memory + daemon functionality.
+    CUDA IPC export functionality for single-copy training.
     """
     import torch
-    import torch.multiprocessing as mp
-    from .weight_updater import weight_updater_process
 
     class PatchedGPUModelRunner(BaseRunner):
         """
-        Patched GPUModelRunner that enables shared memory weight updates.
+        Patched GPUModelRunner that enables CUDA IPC for single-copy training.
 
         After loading the model, this:
-        1. Calls share_memory_() on all parameters to make them accessible
-           from other processes
-        2. Spawns a daemon process that joins NCCL groups with the trainer
-           and receives weight updates
-
-        The daemon copies updates directly into the shared tensors, so
-        vLLM immediately sees the new weights for inference.
+        1. Calls share_memory_() on all parameters
+        2. Exports CUDA IPC handles to vllm_bridge_config.json
+
+        The trainer reads these IPC handles and attaches to the SAME
+        GPU memory, so optimizer.step() updates weights that vLLM
+        immediately sees for inference.
         """
 
         _shared_memory_setup_done = False
-        weight_updater_process = None
 
         def load_model(self, *args, **kwargs) -> None:
            """Load model and set up shared memory + update daemon."""
@@ -171,27 +169,12 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 self._setup_shared_memory()
                 PatchedGPUModelRunner._shared_memory_setup_done = True
-                print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
+                print("[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", flush=True)
             except Exception as e:
                 print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
                 import traceback
                 traceback.print_exc()
                 return
 
-            # Spawn weight updater daemon (optional - can be skipped for HTTP-only mode)
-            skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1"
-            if skip_daemon:
-                print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True)
-                return
-
-            try:
-                print("[vLLM Patch] Spawning weight updater daemon...", flush=True)
-                self._spawn_weight_updater()
-                print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True)
-            except Exception as e:
-                print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True)
-                import traceback
-                traceback.print_exc()
-                print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True)
 
         def _setup_shared_memory(self) -> None:
            """Move model tensors to shared memory and export param info."""
@@ -326,70 +309,6 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 import traceback
                 traceback.print_exc()
 
-        def _spawn_weight_updater(self) -> None:
-            """Start the weight updater as a background thread.
-
-            Note: We use threading instead of multiprocessing because vLLM's
-            worker processes are daemons, and daemons cannot spawn child processes.
-            """
-
-            import threading
-
-            print("[vLLM Patch] _spawn_weight_updater() called", flush=True)
-
-            try:
-                from vllm.distributed import get_tensor_model_parallel_rank
-                print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True)
-            except ImportError as e:
-                print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True)
-                get_tensor_model_parallel_rank = lambda: 0
-
-            # Get model configuration
-            state_dict = self.model.state_dict()
-            print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True)
-
-            # Get attention head counts
-            hf_config = self.model_config.hf_text_config
-            num_heads = getattr(hf_config, "num_attention_heads", 0)
-            num_kv_heads = self.model_config.get_total_num_kv_heads()
-            print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True)
-
-            # Get parallel configuration
-            tp_rank = get_tensor_model_parallel_rank()
-            print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True)
-
-            # Get GPU ID
-            gpu_id = 0
-
-            try:
-                if hasattr(self, 'device'):
-                    if hasattr(self.device, 'index'):
-                        gpu_id = self.device.index or 0
-                    elif isinstance(self.device, int):
-                        gpu_id = self.device
-            except Exception:
-                gpu_id = tp_rank
-
-            print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)
-
-            # Start as a daemon thread (threads CAN be started from daemon processes)
-            self.weight_updater_thread = threading.Thread(
-                target=weight_updater_process,
-                args=(
-                    state_dict,
-                    num_heads,
-                    num_kv_heads,
-                    tp_rank,
-                    self.parallel_config.tensor_parallel_size,
-                    gpu_id,
-                ),
-                daemon=True,
-                name=f"WeightUpdater_TP{tp_rank}",
-            )
-
-            print("[vLLM Patch] Starting thread...", flush=True)
-            self.weight_updater_thread.start()
-            print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True)
-
     # Set proper class name
     PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
    PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
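
The deleted `_spawn_weight_updater` carried a note worth preserving: vLLM's workers run as daemonic processes, and Python's multiprocessing forbids daemonic processes from having children, which is why the old code fell back to a thread. A standalone repro of that constraint:

```python
# A daemonic multiprocessing worker may start threads but not processes.
import multiprocessing as mp
import threading

def worker():
    threading.Thread(target=lambda: None).start()  # allowed in a daemon
    try:
        mp.Process(target=lambda: None).start()    # forbidden in a daemon
    except AssertionError as e:
        print(e)  # "daemonic processes are not allowed to have children"

if __name__ == "__main__":
    p = mp.Process(target=worker, daemon=True)
    p.start()
    p.join()
```

With the IPC bridge there is no updater at all, thread or process, so the constraint no longer applies.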