clearing more bloat

This commit is contained in:
Jai Suphavadeeprasit 2026-01-17 13:49:43 -05:00
parent ab8d2f2dac
commit 036b87e921
4 changed files with 27 additions and 682 deletions

View file

@ -1,13 +1,20 @@
"""
vLLM Patching Module - Enables shared memory weight updates.
vLLM Patching Module - Enables CUDA IPC shared memory for single-copy training.
This module patches vLLM's GPUModelRunner to:
1. Call share_memory_() on model weights after loading
2. Spawn a daemon process that receives NCCL weight updates from trainers
3. Enable real-time weight synchronization without restarting vLLM
2. Export CUDA IPC handles to vllm_bridge_config.json
3. Enable the trainer to attach to vLLM's tensors directly
The result: ONE copy of model weights in GPU memory, shared between
vLLM (inference) and the trainer (gradient updates).
Usage:
# Import this BEFORE importing vllm
# Set environment BEFORE importing
import os
os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
# Import and apply patches BEFORE importing vllm
from example_trainer.vllm_patching import apply_patches
apply_patches()
@ -21,24 +28,10 @@ from .patched_gpu_runner import (
get_patched_runner,
is_patched,
)
from .weight_updater import weight_updater_process
from .distributed_utils import (
init_process_group,
broadcast_object_list,
get_inference_urls,
get_json_data,
)
__all__ = [
"PatchedGPUModelRunner",
"apply_patches",
"get_patched_runner",
"is_patched",
"weight_updater_process",
"init_process_group",
"broadcast_object_list",
"get_inference_urls",
"get_json_data",
]

View file

@ -1,328 +0,0 @@
"""
Distributed utilities for vLLM weight synchronization.
Provides process group initialization and communication helpers
for coordinating weight updates between trainer and vLLM.
"""
from __future__ import annotations
import json
import os
import socket
import time
from collections import defaultdict
from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
def init_process_group(
    backend: Optional[str] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Any] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
) -> dist.ProcessGroup:
    """
    Initialize a custom process group for weight synchronization.

    This creates a named process group that coexists with vLLM's internal
    process groups, enabling direct tensor communication between trainer
    and inference processes.

    Args:
        backend: "nccl" for GPU, "gloo" for CPU
        init_method: Rendezvous URL (e.g., "tcp://host:port")
        timeout: How long to wait for other ranks
        world_size: Total number of processes
        rank: This process's rank
        store: Optional torch.distributed Store
        group_name: Name for this process group (must match across ranks)
        pg_options: Backend-specific options

    Returns:
        ProcessGroup for collective operations
    """
    # Private torch internals -- imported lazily so this module can be
    # imported without initializing torch.distributed machinery.
    from torch.distributed.distributed_c10d import (
        _new_process_group_helper,
        _world,
        Backend,
        default_pg_timeout,
        PrefixStore,
        rendezvous,
    )

    assert (store is None) or (init_method is None), \
        "Cannot specify both init_method and store."
    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # Create store via rendezvous if not provided
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
        # Prefix keys with the group name so multiple groups can share one store.
        store = PrefixStore(group_name, store)

    # Handle PyTorch version differences for the pg_options parameter
    # (renamed to "backend_options" in torch 2.6).
    # FIX: compare version components numerically -- the previous string
    # comparison (str(torch.__version__) >= "2.6") breaks lexicographically
    # for versions such as "2.10".
    version_parts = torch.__version__.split(".")
    try:
        torch_version = (int(version_parts[0]), int(version_parts[1]))
    except (ValueError, IndexError):
        # Unparseable/dev version string: fall back to the legacy name.
        torch_version = (0, 0)
    pg_options_param_name = (
        "backend_options" if torch_version >= (2, 6) else "pg_options"
    )

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )
    # Register the group's rank mapping (identity: global rank == group rank)
    # so torch.distributed collectives can resolve ranks for this group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    return pg
def broadcast_object_list(
    object_list: List[Any],
    src: Optional[int] = None,
    group: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    group_src: Optional[int] = None,
) -> None:
    """
    Broadcast a list of objects from source rank to all other ranks.

    Modified from torch.distributed.broadcast_object_list to work correctly
    with custom process groups where rank 0 may not be the default group's
    rank 0.

    Args:
        object_list: List of objects to broadcast (modified in-place on receivers)
        src: Global source rank (deprecated, use group_src)
        group: Process group to use
        device: Device for temporary tensors
        group_src: Source rank within the group
    """
    # group_src takes precedence over the deprecated src argument.
    global_src = group_src if group_src is not None else src
    current_device = device

    # Broadcast object sizes first.
    # NOTE(review): both tensors below are allocated with torch.empty and are
    # never filled before the broadcast, so this function only behaves sensibly
    # on RECEIVING ranks -- if called on the src rank it would broadcast
    # uninitialized memory. Presumably the trainer (sender) uses torch's stock
    # broadcast_object_list; confirm against the trainer-side code.
    object_sizes_tensor = torch.empty(
        len(object_list), dtype=torch.long, device=current_device
    )
    dist.broadcast(object_sizes_tensor, src=global_src, group=group)

    # Broadcast serialized objects as one flat uint8 buffer whose total
    # length is the sum of the per-object sizes received above.
    object_tensor = torch.empty(
        torch.sum(object_sizes_tensor).item(),
        dtype=torch.uint8,
        device=current_device,
    )
    dist.broadcast(object_tensor, src=global_src, group=group)

    # Deserialize objects: slice the flat buffer per object and unpickle.
    offset = 0
    for i, obj_size in enumerate(object_sizes_tensor):
        obj_view = object_tensor[offset : offset + obj_size]
        obj_view = obj_view.type(torch.uint8)
        offset += obj_size
        # _tensor_to_object is a private torch helper that unpickles a uint8
        # tensor back into a Python object -- its signature may shift across
        # torch versions.
        object_list[i] = dist.distributed_c10d._tensor_to_object(
            obj_view, obj_size, group
        )
def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]:
    """
    Get URLs for inference server communication.

    Parses SLURM environment or uses localhost for single-machine setup.

    Args:
        num_inference_nodes: Number of dedicated inference nodes.
            0 = single machine, trainer and vLLM share the node
            >0 = multi-node, last N nodes are for inference

    Returns:
        Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist)
        Returns (None, None, None, None) if not in a valid setup.
    """
    if num_inference_nodes == 0:
        # Single machine: everything rendezvouses on localhost, fixed ports.
        return (
            "localhost:26756",
            "localhost:26757",
            "localhost:26758",
            ["localhost"],
        )

    if num_inference_nodes < 0:
        # A negative node count is not a valid configuration.
        return None, None, None, None

    # Multi-node SLURM setup: requires the SLURM node list in the environment.
    slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST")
    if not slurm_nodelist:
        return None, None, None, None

    # Expand SLURM's compressed node-list syntax into individual hostnames.
    raw = os.popen(f'scontrol show hostnames {slurm_nodelist}').read()
    hostnames = [host for host in raw.strip().split("\n") if host]

    # First node hosts the NCCL and Gloo rendezvous servers; the last N
    # nodes are dedicated to inference, with the first of those hosting
    # the inference rendezvous server.
    head = hostnames[0]
    inference_nodes = hostnames[-num_inference_nodes:]
    return (
        f"{head}:26756",
        f"{head}:26757",
        f"{inference_nodes[0]}:26758",
        inference_nodes,
    )
def get_hostnames() -> Optional[List[str]]:
    """
    Get the hostnames for this machine.

    Parses /etc/hosts to find all hostnames associated with this machine's IP.

    Returns:
        List of [ip, hostname1, hostname2, ...] or None if not found.
    """
    try:
        # FIX: resolve inside the try block -- gethostbyname() raises
        # socket.gaierror on hosts with a broken resolver / /etc/hosts setup,
        # and the documented contract is to return None on failure, not raise.
        my_hostname = socket.gethostname()
        my_ip = socket.gethostbyname(my_hostname)
        with open("/etc/hosts", "r") as f:
            for line in f:
                line = line.strip()
                # Skip blanks and comment lines.
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                # Match by IP or by hostname appearing among the aliases.
                if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)):
                    # Skip loopback entries -- not reachable from other nodes.
                    if parts[0].startswith("127."):
                        continue
                    return parts
    except Exception:
        # Best-effort lookup: any failure (no /etc/hosts, resolver error,
        # permissions) degrades to "not found".
        pass
    return None
def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]:
    """
    Load the bridge configuration JSON from vLLM.

    Waits for the file to be created by vLLM's weight bridge setup.

    Args:
        log_dir: Directory containing the JSON file (defaults to LOGDIR env var)
        timeout: Maximum seconds to wait for file

    Returns:
        Parsed JSON data with parameter mappings and configuration.

    Raises:
        ValueError: If LOGDIR not set and log_dir not provided
        FileNotFoundError: If file not found after timeout
    """
    directory = log_dir if log_dir is not None else os.environ.get("LOGDIR")
    if directory is None:
        raise ValueError("LOGDIR environment variable not set and log_dir not provided")

    json_path = os.path.join(directory, "vllm_bridge_config.json")

    # Poll once per second until the file shows up or the timeout elapses,
    # logging a progress line every 10 seconds.
    elapsed = 0
    while not os.path.exists(json_path):
        if elapsed >= timeout:
            raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}")
        if elapsed % 10 == 0:
            print(f"[Updater] Waiting for {json_path}...", flush=True)
        time.sleep(1)
        elapsed += 1

    # Wait a moment for the writer to finish flushing the file.
    time.sleep(0.5)

    with open(json_path, "r") as f:
        return json.load(f)
def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]:
    """
    Build reverse mapping from vLLM names to trainer names.

    Args:
        param_mappings: Dict mapping trainer param names to vLLM info

    Returns:
        Dict mapping vLLM names to list of trainer names
    """
    reverse_map: Dict[str, List[str]] = defaultdict(list)
    for trainer_name, info in param_mappings.items():
        # Fall back to the trainer name itself when no vLLM alias is given.
        reverse_map[info.get("vllm_name", trainer_name)].append(trainer_name)
    return reverse_map
# Permutation functions for rotary embeddings
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Permute weight tensor for sliced rotary embeddings.

    Args:
        w: Weight tensor of shape [dim1, dim2]
        n_heads: Number of attention heads

    Returns:
        Permuted tensor for rotary embedding compatibility
    """
    rows, cols = w.shape
    # Split each head's rows into (half, 2) pairs, swap the pair axis with
    # the half axis, then flatten back to the original [rows, cols] shape.
    grouped = w.view(n_heads, rows // n_heads // 2, 2, cols)
    return grouped.transpose(1, 2).reshape(rows, cols)
def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Permute 1D weight tensor (bias) for sliced rotary embeddings.

    Args:
        w: Weight tensor of shape [dim1]
        n_heads: Number of attention heads

    Returns:
        Permuted tensor
    """
    length = w.shape[0]
    half = length // n_heads // 2
    # Same pair/half swap as permute(), applied to a flat bias vector.
    pairs = w.view(n_heads, half, 2)
    return pairs.transpose(1, 2).reshape(length)

View file

@ -1,13 +1,15 @@
"""
Patched GPU Model Runner - Enables shared memory weight updates.
Patched GPU Model Runner - Enables CUDA IPC for single-copy training.
This patches vLLM's GPUModelRunner to:
1. Call share_memory_() on model weights after loading
2. Spawn a daemon process that receives NCCL weight updates from trainers
2. Export CUDA IPC handles to vllm_bridge_config.json
The key insight is that share_memory_() makes tensors accessible from
multiple processes. The daemon receives updates via NCCL and copies them
directly into the shared tensors, which vLLM reads for inference.
The key insight is that CUDA IPC handles allow the trainer process to
attach to the EXACT SAME GPU memory that vLLM uses. This means:
- ONE copy of model weights in GPU memory
- Trainer's optimizer.step() updates vLLM's weights directly
- No synchronization needed - vLLM immediately sees new weights
CRITICAL: This module must be imported and apply_patches() called BEFORE
any vLLM imports. The patches MUST happen before vLLM caches module references.
@ -119,28 +121,24 @@ def _create_patched_runner(BaseRunner: type) -> type:
Create a patched GPUModelRunner class.
Returns a new class that inherits from the original and adds
shared memory + daemon functionality.
CUDA IPC export functionality for single-copy training.
"""
import torch
import torch.multiprocessing as mp
from .weight_updater import weight_updater_process
class PatchedGPUModelRunner(BaseRunner):
"""
Patched GPUModelRunner that enables shared memory weight updates.
Patched GPUModelRunner that enables CUDA IPC for single-copy training.
After loading the model, this:
1. Calls share_memory_() on all parameters to make them accessible
from other processes
2. Spawns a daemon process that joins NCCL groups with the trainer
and receives weight updates
The daemon copies updates directly into the shared tensors, so
vLLM immediately sees the new weights for inference.
1. Calls share_memory_() on all parameters
2. Exports CUDA IPC handles to vllm_bridge_config.json
The trainer reads these IPC handles and attaches to the SAME
GPU memory, so optimizer.step() updates weights that vLLM
immediately sees for inference.
"""
_shared_memory_setup_done = False
weight_updater_process = None
def load_model(self, *args, **kwargs) -> None:
"""Load model and set up shared memory + update daemon."""
@ -171,27 +169,12 @@ def _create_patched_runner(BaseRunner: type) -> type:
self._setup_shared_memory()
PatchedGPUModelRunner._shared_memory_setup_done = True
print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
print("[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", flush=True)
except Exception as e:
print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
import traceback
traceback.print_exc()
return
# Spawn weight updater daemon (optional - can be skipped for HTTP-only mode)
skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1"
if skip_daemon:
print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True)
return
try:
print("[vLLM Patch] Spawning weight updater daemon...", flush=True)
self._spawn_weight_updater()
print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True)
except Exception as e:
print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True)
import traceback
traceback.print_exc()
print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True)
def _setup_shared_memory(self) -> None:
"""Move model tensors to shared memory and export param info."""
@ -326,70 +309,6 @@ def _create_patched_runner(BaseRunner: type) -> type:
import traceback
traceback.print_exc()
def _spawn_weight_updater(self) -> None:
    """Start the weight updater as a background thread.

    Note: We use threading instead of multiprocessing because vLLM's
    worker processes are daemons, and daemons cannot spawn child processes.
    """
    import threading
    print("[vLLM Patch] _spawn_weight_updater() called", flush=True)

    # The TP-rank helper lives in vLLM; fall back to rank 0 when it is
    # unavailable (e.g. patched against an older vLLM layout).
    try:
        from vllm.distributed import get_tensor_model_parallel_rank
        print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True)
    except ImportError as e:
        print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True)
        get_tensor_model_parallel_rank = lambda: 0

    # Get model configuration: the live state_dict is handed to the updater
    # thread, which copies received weights into these tensors in place.
    state_dict = self.model.state_dict()
    print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True)

    # Get attention head counts (used by the updater for rotary permutations).
    hf_config = self.model_config.hf_text_config
    num_heads = getattr(hf_config, "num_attention_heads", 0)
    num_kv_heads = self.model_config.get_total_num_kv_heads()
    print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True)

    # Get parallel configuration
    tp_rank = get_tensor_model_parallel_rank()
    print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True)

    # Get GPU ID. self.device may be a torch.device (use .index, which can
    # be None -> treated as 0) or a bare int; default to 0, or to tp_rank
    # if probing the device raises.
    gpu_id = 0
    try:
        if hasattr(self, 'device'):
            if hasattr(self.device, 'index'):
                gpu_id = self.device.index or 0
            elif isinstance(self.device, int):
                gpu_id = self.device
    except Exception:
        gpu_id = tp_rank

    print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)

    # Start as a daemon thread (threads CAN be started from daemon processes)
    self.weight_updater_thread = threading.Thread(
        target=weight_updater_process,
        args=(
            state_dict,
            num_heads,
            num_kv_heads,
            tp_rank,
            self.parallel_config.tensor_parallel_size,
            gpu_id,
        ),
        daemon=True,
        name=f"WeightUpdater_TP{tp_rank}",
    )
    print("[vLLM Patch] Starting thread...", flush=True)
    self.weight_updater_thread.start()
    print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True)
# Set proper class name
PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"

View file

@ -1,239 +0,0 @@
"""
Weight Updater Process - Daemon that receives NCCL weight updates.
This process runs as a daemon spawned by the patched vLLM GPUModelRunner.
It joins NCCL process groups with the trainer and receives weight updates,
copying them directly into vLLM's shared memory tensors.
"""
from __future__ import annotations
import os
import time
from typing import Dict
import torch
import torch.distributed as dist
from .distributed_utils import (
init_process_group,
get_inference_urls,
get_hostnames,
)
def weight_updater_process(
    state_dict: Dict[str, torch.Tensor],
    num_q_heads: int,
    num_kv_heads: int,
    tp_rank: int,
    tp_size: int,
    gpu_id: int,
) -> None:
    """
    Daemon process that receives weight updates from trainers via NCCL.

    This runs inside a subprocess spawned by PatchedGPUModelRunner. It:
    1. Joins NCCL/Gloo process groups with the trainer
    2. Receives weight update broadcasts from rank 0 (trainer)
    3. Copies updated weights directly into the shared state_dict

    Since state_dict tensors have share_memory_() called on them, the main
    vLLM process immediately sees the updates for inference.

    Args:
        state_dict: Model state dict with shared memory tensors
        num_q_heads: Number of query attention heads (for permutation)
        num_kv_heads: Number of key/value attention heads
        tp_rank: Tensor parallel rank of this worker
        tp_size: Total tensor parallel size
        gpu_id: GPU device ID for this worker
    """
    # Configuration from environment
    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
    debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))

    # Get network info (rendezvous addresses + node list)
    master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
        num_inference_nodes
    )
    if master_addr is None:
        print(f"[Updater] Master address not found, exiting", flush=True)
        return

    # Set CUDA device
    torch.cuda.set_device(tp_rank)
    print(
        f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
        f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
        flush=True,
    )

    # For single-node mode (num_inference_nodes=0):
    #   - Trainer is rank 0
    #   - Inference daemon is rank 1 (or tp_rank + 1 for multi-GPU)
    #   Total world size = 1 trainer + 1 inference = 2
    #
    # For multi-node mode:
    #   - More complex, based on SLURM node allocation
    if num_inference_nodes == 0:
        # Single node: simple setup
        # World = [trainer (rank 0), inference daemon (rank 1)]
        num_training_ranks = 1
        num_inference_ranks = 1
        world_size = num_training_ranks + num_inference_ranks
        rank = num_training_ranks + tp_rank  # Daemon is rank 1
    else:
        # Multi-node: 8 GPUs per node.
        # Derive this daemon's rank from which inference node we are on and
        # which physical GPU we own.
        # NOTE(review): world_size here counts only inference ranks
        # (num_inference_nodes * 8) with no trainer ranks added -- confirm
        # this matches the trainer's world-size computation.
        hostnames = get_hostnames()
        cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
        ranks_per_node = 8
        world_size = num_inference_nodes * ranks_per_node
        rank = -1
        for i, url in enumerate(urls or []):
            if hostnames and url in hostnames:
                rank = ranks_per_node * i + int(cuda_devices[gpu_id])
                break
        if rank < 0:
            print(f"[Updater] Could not determine rank for multi-node, exiting", flush=True)
            return

    print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)

    # Use state_dict keys as parameter list (we already have the model!)
    # Sorted so that trainer and daemon agree on index -> name mapping.
    param_name_list = sorted(state_dict.keys())
    print(f"[Updater] Model has {len(param_name_list)} parameters", flush=True)

    # Use the world_size and rank we already calculated
    total_group_size = world_size
    # For single-node mode, trainer is rank 0
    # NOTE(review): both branches of this conditional yield 1, so the
    # expression is redundant as written.
    num_training_ranks = 1 if num_inference_nodes == 0 else 1
    print(f"[Updater] Total group size: {total_group_size}", flush=True)
    print(f"[Updater] Training ranks: {num_training_ranks}", flush=True)
    print(f"[Updater] My rank: {rank}", flush=True)

    # Initialize process groups
    print("[Updater] Creating process groups...", flush=True)
    try:
        # Gloo group for coordination
        gloo_group = init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="gloo_group",
        )
        print("[Updater] ✓ Gloo group created", flush=True)
        # NCCL group for tensor transfers
        nccl_group = init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="weight_update_group",
        )
        print("[Updater] ✓ NCCL group created", flush=True)
        # Barrier synchronization to confirm both sides are ready
        print("[Updater] Waiting for trainer to be ready...", flush=True)
        dist.barrier(group=gloo_group)
        print("[Updater] ✓ Trainer is ready, starting update loop", flush=True)
    except Exception as e:
        print(f"[Updater] Failed to create process groups: {e}", flush=True)
        import traceback
        traceback.print_exc()
        return

    # Get device for tensors (all state_dict tensors assumed on one device).
    my_device = next(iter(state_dict.values())).device

    # Build param info dict from state_dict
    # NOTE(review): param_info_dict is built here but never read below --
    # dead code candidate, or intended for a removed feature.
    param_info_dict = {}
    for name, tensor in state_dict.items():
        param_info_dict[name] = {
            "shape": list(tensor.shape),
            "dtype": tensor.dtype,
        }

    print("[Updater] Entering update loop...", flush=True)
    print(f"[Updater] Waiting for weight updates from trainer (rank 0)...", flush=True)
    update_count = 0
    with torch.no_grad():
        while True:
            try:
                # Receive parameter index from trainer (rank 0)
                obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)
                dist.broadcast(obj_indx, src=0, group=nccl_group)
                tt_indx = obj_indx.item()
                # -1 signals heartbeat (no update)
                if tt_indx == -1:
                    continue
                # -2 signals shutdown
                if tt_indx == -2:
                    print("[Updater] Received shutdown signal", flush=True)
                    break
                # Get parameter info; ignore out-of-range indices.
                if tt_indx < 0 or tt_indx >= len(param_name_list):
                    if debug:
                        print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
                    continue
                param_name = param_name_list[tt_indx]
                if param_name not in state_dict:
                    if debug:
                        print(f"[Updater] {param_name} not in state_dict, skipping", flush=True)
                    continue
                target_tensor = state_dict[param_name]
                target_shape = list(target_tensor.shape)
                target_dtype = target_tensor.dtype
                # Receive the tensor from trainer.
                # Trainer sends via broadcast, we receive into a scratch
                # buffer shaped like the target parameter.
                received_tensor = torch.zeros(target_shape, dtype=target_dtype, device=my_device)
                dist.broadcast(received_tensor, src=0, group=nccl_group)
                # Copy to shared memory -- vLLM sees the update immediately.
                state_dict[param_name].data.copy_(received_tensor)
                update_count += 1
                if debug or (update_count % 50 == 0):
                    print(f"[Updater] Updated {param_name} (#{update_count})", flush=True)
            except torch.distributed.DistBackendError as e:
                # NCCL communication failure - likely trainer crashed
                error_str = str(e)
                if "Broken pipe" in error_str or "Connection reset" in error_str:
                    print("[Updater] Trainer disconnected (broken pipe). Exiting.", flush=True)
                    break
                else:
                    print(f"[Updater] NCCL error: {e}", flush=True)
                    import traceback
                    traceback.print_exc()
                    # Back off briefly before retrying the loop.
                    time.sleep(1)
            except Exception as e:
                print(f"[Updater] Error in update loop: {e}", flush=True)
                import traceback
                traceback.print_exc()
                time.sleep(1)
# Note: Advanced multi-GPU tensor parallelism support removed for simplicity.
# For single-node mode, we use direct tensor broadcast which is sufficient.