changes based on torchtitan

This commit is contained in:
Jai Suphavadeeprasit 2025-12-28 12:27:29 -05:00
parent 078dd4a333
commit 53b29472b4
7 changed files with 1535 additions and 1977 deletions

View file

@ -0,0 +1,37 @@
"""
vLLM Patching Module - Enables shared memory weight updates.
This module patches vLLM's GPUModelRunner to:
1. Call share_memory_() on model weights after loading
2. Spawn a daemon process that receives NCCL weight updates from trainers
3. Enable real-time weight synchronization without restarting vLLM
Usage:
# Import this BEFORE importing vllm
from example_trainer.vllm_patching import apply_patches
apply_patches()
# Then import vllm normally
from vllm import AsyncLLM
"""
from .patched_gpu_runner import PatchedGPUModelRunner, apply_patches
from .weight_updater import weight_updater_process
from .distributed_utils import (
init_process_group,
broadcast_object_list,
get_inference_urls,
get_json_data,
)
__all__ = [
"PatchedGPUModelRunner",
"apply_patches",
"weight_updater_process",
"init_process_group",
"broadcast_object_list",
"get_inference_urls",
"get_json_data",
]

View file

@ -0,0 +1,328 @@
"""
Distributed utilities for vLLM weight synchronization.
Provides process group initialization and communication helpers
for coordinating weight updates between trainer and vLLM.
"""
from __future__ import annotations
import json
import os
import socket
import time
from collections import defaultdict
from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
def init_process_group(
backend: Optional[str] = None,
init_method: Optional[str] = None,
timeout: Optional[timedelta] = None,
world_size: int = -1,
rank: int = -1,
store: Optional[Any] = None,
group_name: str = "",
pg_options: Optional[Any] = None,
) -> dist.ProcessGroup:
"""
Initialize a custom process group for weight synchronization.
This creates a named process group that coexists with vLLM's internal
process groups, enabling direct tensor communication between trainer
and inference processes.
Args:
backend: "nccl" for GPU, "gloo" for CPU
init_method: Rendezvous URL (e.g., "tcp://host:port")
timeout: How long to wait for other ranks
world_size: Total number of processes
rank: This process's rank
store: Optional torch.distributed Store
group_name: Name for this process group (must match across ranks)
pg_options: Backend-specific options
Returns:
ProcessGroup for collective operations
"""
from torch.distributed.distributed_c10d import (
_new_process_group_helper,
_world,
Backend,
default_pg_timeout,
PrefixStore,
rendezvous,
)
assert (store is None) or (init_method is None), \
"Cannot specify both init_method and store."
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")
if timeout is None:
timeout = default_pg_timeout
# Create store via rendezvous if not provided
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
store = PrefixStore(group_name, store)
# Handle PyTorch version differences for pg_options parameter
pg_options_param_name = (
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
)
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
return pg
def broadcast_object_list(
object_list: List[Any],
src: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
) -> None:
"""
Broadcast a list of objects from source rank to all other ranks.
Modified from torch.distributed.broadcast_object_list to work correctly
with custom process groups where rank 0 may not be the default group's rank 0.
Args:
object_list: List of objects to broadcast (modified in-place on receivers)
src: Global source rank (deprecated, use group_src)
group: Process group to use
device: Device for temporary tensors
group_src: Source rank within the group
"""
global_src = group_src if group_src is not None else src
current_device = device
# Broadcast object sizes first
object_sizes_tensor = torch.empty(
len(object_list), dtype=torch.long, device=current_device
)
dist.broadcast(object_sizes_tensor, src=global_src, group=group)
# Broadcast serialized objects
object_tensor = torch.empty(
torch.sum(object_sizes_tensor).item(),
dtype=torch.uint8,
device=current_device,
)
dist.broadcast(object_tensor, src=global_src, group=group)
# Deserialize objects
offset = 0
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8)
offset += obj_size
object_list[i] = dist.distributed_c10d._tensor_to_object(
obj_view, obj_size, group
)
def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]:
"""
Get URLs for inference server communication.
Parses SLURM environment or uses localhost for single-machine setup.
Args:
num_inference_nodes: Number of dedicated inference nodes.
0 = single machine, trainer and vLLM share the node
>0 = multi-node, last N nodes are for inference
Returns:
Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist)
Returns (None, None, None, None) if not in a valid setup.
"""
if num_inference_nodes > 0:
# Multi-node SLURM setup
slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST")
if not slurm_nodelist:
return None, None, None, None
# Parse SLURM node list
nodelist = (
os.popen(f'scontrol show hostnames {slurm_nodelist}')
.read()
.strip()
.split("\n")
)
nodelist = [node for node in nodelist if node]
# First node is master for process groups
master_server = f"{nodelist[0]}:26756"
master_gloo_server = f"{nodelist[0]}:26757"
# Last N nodes are inference nodes
inference_nodes = nodelist[-num_inference_nodes:]
master_inference_server = f"{inference_nodes[0]}:26758"
return master_server, master_gloo_server, master_inference_server, inference_nodes
elif num_inference_nodes == 0:
# Single machine setup
master_server = "localhost:26756"
master_gloo_server = "localhost:26757"
master_inference_server = "localhost:26758"
nodelist = ["localhost"]
return master_server, master_gloo_server, master_inference_server, nodelist
else:
return None, None, None, None
def get_hostnames() -> Optional[List[str]]:
"""
Get the hostnames for this machine.
Parses /etc/hosts to find all hostnames associated with this machine's IP.
Returns:
List of [ip, hostname1, hostname2, ...] or None if not found.
"""
my_ip = socket.gethostbyname(socket.gethostname())
my_hostname = socket.gethostname()
try:
with open("/etc/hosts", "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
parts = line.split()
if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)):
ip = parts[0]
if ip.startswith("127."):
continue
return parts
except Exception:
pass
return None
def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]:
"""
Load the bridge configuration JSON from vLLM.
Waits for the file to be created by vLLM's weight bridge setup.
Args:
log_dir: Directory containing the JSON file (defaults to LOGDIR env var)
timeout: Maximum seconds to wait for file
Returns:
Parsed JSON data with parameter mappings and configuration.
Raises:
ValueError: If LOGDIR not set and log_dir not provided
FileNotFoundError: If file not found after timeout
"""
if log_dir is None:
log_dir = os.environ.get("LOGDIR")
if log_dir is None:
raise ValueError("LOGDIR environment variable not set and log_dir not provided")
json_path = os.path.join(log_dir, "vllm_bridge_config.json")
wait_time = 0
while not os.path.exists(json_path):
if wait_time >= timeout:
raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}")
if wait_time % 10 == 0:
print(f"[Updater] Waiting for {json_path}...", flush=True)
time.sleep(1)
wait_time += 1
# Wait a moment for file to finish writing
time.sleep(0.5)
with open(json_path, "r") as f:
return json.load(f)
def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]:
"""
Build reverse mapping from vLLM names to trainer names.
Args:
param_mappings: Dict mapping trainer param names to vLLM info
Returns:
Dict mapping vLLM names to list of trainer names
"""
name_conversions = defaultdict(list)
for name, info in param_mappings.items():
vllm_name = info.get("vllm_name", name)
name_conversions[vllm_name].append(name)
return name_conversions
# Permutation functions for rotary embeddings
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
"""
Permute weight tensor for sliced rotary embeddings.
Args:
w: Weight tensor of shape [dim1, dim2]
n_heads: Number of attention heads
Returns:
Permuted tensor for rotary embedding compatibility
"""
dim1 = w.shape[0]
dim2 = w.shape[1]
return (
w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
.transpose(1, 2)
.reshape(dim1, dim2)
)
def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor:
"""
Permute 1D weight tensor (bias) for sliced rotary embeddings.
Args:
w: Weight tensor of shape [dim1]
n_heads: Number of attention heads
Returns:
Permuted tensor
"""
dim1 = w.shape[0]
return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1)

View file

@ -0,0 +1,425 @@
"""
Weight Updater Process - Daemon that receives NCCL weight updates.
This process runs as a daemon spawned by the patched vLLM GPUModelRunner.
It joins NCCL process groups with the trainer and receives weight updates,
copying them directly into vLLM's shared memory tensors.
"""
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, List, Optional
import torch
import torch.distributed as dist
from .distributed_utils import (
init_process_group,
get_inference_urls,
get_hostnames,
get_json_data,
get_name_conversions,
permute,
permute_1d,
)
def weight_updater_process(
state_dict: Dict[str, torch.Tensor],
num_q_heads: int,
num_kv_heads: int,
tp_rank: int,
tp_size: int,
gpu_id: int,
) -> None:
"""
Daemon process that receives weight updates from trainers via NCCL.
This runs inside a subprocess spawned by PatchedGPUModelRunner. It:
1. Joins NCCL/Gloo process groups with the trainer
2. Receives weight update broadcasts from rank 0 (trainer)
3. Copies updated weights directly into the shared state_dict
Since state_dict tensors have share_memory_() called on them, the main
vLLM process immediately sees the updates for inference.
Args:
state_dict: Model state dict with shared memory tensors
num_q_heads: Number of query attention heads (for permutation)
num_kv_heads: Number of key/value attention heads
tp_rank: Tensor parallel rank of this worker
tp_size: Total tensor parallel size
gpu_id: GPU device ID for this worker
"""
# Configuration from environment
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))
# Determine world size based on setup
if num_inference_nodes > 0:
# Multi-node: 8 GPUs per node
world_size = num_inference_nodes * 8
ranks_per_node = 8
else:
# Single node: typically 4 inference GPUs
world_size = 4
ranks_per_node = 4
# Get network info
hostnames = get_hostnames()
master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
num_inference_nodes
)
if master_addr is None:
print(f"[Updater] Master address not found, exiting", flush=True)
return
# Set CUDA device
torch.cuda.set_device(tp_rank)
print(
f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
flush=True,
)
print(f"[Updater] Master: {master_addr}, world_size={world_size}", flush=True)
# Determine this worker's rank within the inference group
rank = -1
if num_inference_nodes == 0:
# Single node: skip first N GPUs (used by trainer)
rank = int(cuda_devices[gpu_id]) - (8 - ranks_per_node)
else:
# Multi-node: find which inference node we're on
for i, url in enumerate(urls):
if hostnames and url in hostnames:
rank = ranks_per_node * i + int(cuda_devices[gpu_id])
break
if rank < 0:
print(f"[Updater] Could not determine rank, exiting", flush=True)
return
# Load config from vLLM
print("[Updater] Loading bridge config...", flush=True)
try:
json_data = get_json_data()
except Exception as e:
print(f"[Updater] Failed to load config: {e}", flush=True)
return
param_name_list = sorted(json_data.get("param_mappings", {}).keys())
num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
total_group_size = num_training_gpus + world_size
# Offset rank by training GPUs
rank = rank + num_training_gpus
print(f"[Updater] Total group size: {total_group_size}", flush=True)
print(f"[Updater] Training GPUs: {num_training_gpus}", flush=True)
print(f"[Updater] My rank: {rank}", flush=True)
# Initialize process groups
print("[Updater] Creating process groups...", flush=True)
try:
# Gloo group for coordination
gloo_group = init_process_group(
backend="gloo",
init_method=f"tcp://{master_addr}",
world_size=total_group_size,
rank=rank,
group_name="gloo_group",
)
print("[Updater] ✓ Gloo group created", flush=True)
# NCCL group for tensor transfers
nccl_group = init_process_group(
backend="nccl",
init_method=f"tcp://{master_addr}",
world_size=total_group_size,
rank=rank,
group_name="weight_update_group",
)
print("[Updater] ✓ NCCL group created", flush=True)
except Exception as e:
print(f"[Updater] Failed to create process groups: {e}", flush=True)
return
# Get device for tensors
my_device = next(iter(state_dict.values())).device
# Write dtype mapping if rank 0
if rank == num_training_gpus: # First inference rank
_write_dtype_mapping(state_dict, json_data)
print("[Updater] Entering update loop...", flush=True)
# Buffers for merged QKV and gate_up projections
qkv_buffer = {}
gate_up_buffer = {}
qkv_bias_buffer = {}
w1w3_buffer = {}
with torch.no_grad():
while True:
try:
# Receive parameter index from trainer (rank 0)
obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)
dist.broadcast(obj_indx, src=0, group=nccl_group)
tt_indx = obj_indx.item()
# -1 signals no update this round (heartbeat)
if tt_indx == -1:
continue
# Get parameter info
if tt_indx >= len(param_name_list):
print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
continue
tt_name = param_name_list[tt_indx]
param_info = json_data["param_mappings"].get(tt_name, {})
vllm_name = param_info.get("vllm_name", tt_name)
local_shape = param_info.get("local_shape", [])
if vllm_name not in state_dict:
if debug:
print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
continue
target_dtype = state_dict[vllm_name].dtype
if debug:
print(
f"[Updater] Receiving {tt_name} -> {vllm_name}, "
f"shape={local_shape}, dtype={target_dtype}",
flush=True,
)
# Gather tensors from all training ranks
tensor_list = [
torch.zeros(
local_shape if idx < num_training_gpus else [1],
dtype=target_dtype,
device=my_device,
)
for idx in range(total_group_size)
]
dist.all_gather(
tensor_list,
torch.zeros(1, dtype=target_dtype, device=my_device),
group=nccl_group,
)
# Only keep training tensors
tensor_list = tensor_list[:num_training_gpus]
# Merge tensors from different parallel configurations
tensor = _merge_tensors(
tensor_list,
json_data,
param_info,
state_dict[vllm_name],
)
# Apply updates (handling merged QKV, gate_up, etc.)
_apply_weight_update(
state_dict,
vllm_name,
tt_name,
tensor,
param_info,
num_q_heads,
num_kv_heads,
qkv_buffer,
gate_up_buffer,
qkv_bias_buffer,
w1w3_buffer,
debug,
)
except Exception as e:
print(f"[Updater] Error in update loop: {e}", flush=True)
import traceback
traceback.print_exc()
time.sleep(1)
def _write_dtype_mapping(
state_dict: Dict[str, torch.Tensor],
json_data: Dict[str, Any],
) -> None:
"""Write dtype mapping file for trainer reference."""
try:
log_dir = os.environ.get("LOGDIR", ".")
name_conversions = get_name_conversions(json_data.get("param_mappings", {}))
weight_dtypes = {}
for name in state_dict.keys():
tt_names = name_conversions.get(name, [name])
for tt_name in tt_names:
weight_dtypes[tt_name] = str(state_dict[name].dtype).split(".")[-1]
with open(f"{log_dir}/vllm_dtypes.json", "w") as f:
json.dump(weight_dtypes, f, indent=2)
print("[Updater] Wrote dtype mapping", flush=True)
except Exception as e:
print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)
def _merge_tensors(
tensor_list: List[torch.Tensor],
json_data: Dict[str, Any],
param_info: Dict[str, Any],
target_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Merge tensors from distributed training into single tensor.
Handles FSDP (data parallel) and TP (tensor parallel) sharding.
"""
dp_shard_degree = json_data.get("dp_shard_degree", 1)
tp_degree = json_data.get("tp_degree", 1)
tp_shard_dim = param_info.get("tp_shard_dim", 0)
if dp_shard_degree > 1:
# First merge across data parallel dimension
tp_tensors = []
for i in range(tp_degree):
dp_tensors = tensor_list[i::tp_degree]
tp_tensors.append(torch.cat(dp_tensors, dim=0))
# Then merge across tensor parallel dimension if needed
if tp_degree > 1:
if tp_tensors[0].shape == target_tensor.shape:
tensor = tp_tensors[0].contiguous()
else:
tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
else:
tensor = tp_tensors[0].contiguous()
else:
# No FSDP, just merge TP shards
tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()
# Cast to target dtype if needed
if tensor.dtype != target_tensor.dtype:
tensor = tensor.to(target_tensor.dtype)
return tensor
def _apply_weight_update(
state_dict: Dict[str, torch.Tensor],
vllm_name: str,
tt_name: str,
tensor: torch.Tensor,
param_info: Dict[str, Any],
num_q_heads: int,
num_kv_heads: int,
qkv_buffer: Dict[str, torch.Tensor],
gate_up_buffer: Dict[str, torch.Tensor],
qkv_bias_buffer: Dict[str, torch.Tensor],
w1w3_buffer: Dict[str, torch.Tensor],
debug: bool,
) -> None:
"""
Apply weight update to state_dict, handling merged projections.
vLLM often merges QKV projections and gate/up projections into single
tensors for efficiency. This handles unpacking and merging correctly.
"""
needs_permute = param_info.get("needs_permute", False)
shape = param_info.get("shape", list(tensor.shape))
def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
if debug:
diff = (new.float() - old.float()).abs()
print(
f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
f"std={diff.std().item():.6e}",
flush=True,
)
# Handle merged QKV projection weights
if "qkv_proj.weight" in vllm_name:
key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
"v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
if key_val == "q" and needs_permute:
tensor = permute(tensor, num_q_heads)
elif key_val == "k" and needs_permute:
tensor = permute(tensor, num_kv_heads)
qkv_buffer[key_val] = tensor
if len(qkv_buffer) == 3:
merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
qkv_buffer.clear()
# Handle merged gate/up projection weights
elif "gate_up_proj.weight" in vllm_name:
key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
gate_up_buffer[key_val] = tensor
if len(gate_up_buffer) == 2:
merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
gate_up_buffer.clear()
# Handle merged w1/w3 weights (alternative naming)
elif "w13_weight" in vllm_name:
key_val = "w1" if ".w1" in tt_name else "w3"
w1w3_buffer[key_val] = tensor
if len(w1w3_buffer) == 2:
merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
w1w3_buffer.clear()
# Handle merged QKV bias
elif "qkv_proj.bias" in vllm_name:
key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k"
if key_val == "q" and needs_permute:
tensor = permute_1d(tensor, num_q_heads)
elif key_val == "k" and needs_permute:
tensor = permute_1d(tensor, num_kv_heads)
qkv_bias_buffer[key_val] = tensor
if len(qkv_bias_buffer) == 3:
merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
qkv_bias_buffer.clear()
# Handle regular weights (possibly needing permutation)
elif needs_permute:
if len(shape) == 2:
tensor = permute(tensor, shape[0]).contiguous()
elif len(shape) == 1:
tensor = permute_1d(tensor, shape[0]).contiguous()
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
state_dict[vllm_name].data.copy_(tensor)
# Simple weight copy
else:
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
state_dict[vllm_name].data.copy_(tensor)