# atropos/example_trainer/vllm_patching/weight_updater.py
# Jai Suphavadeeprasit e2c99f7f97 daemon errors
# 2026-03-02 11:18:51 -05:00
#
# 437 lines
# 16 KiB
# Python

"""
Weight Updater Process - Daemon that receives NCCL weight updates.
This process runs as a daemon spawned by the patched vLLM GPUModelRunner.
It joins NCCL process groups with the trainer and receives weight updates,
copying them directly into vLLM's shared memory tensors.
"""
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, List, Optional
import torch
import torch.distributed as dist
from .distributed_utils import (
init_process_group,
get_inference_urls,
get_hostnames,
get_json_data,
get_name_conversions,
permute,
permute_1d,
)
def weight_updater_process(
    state_dict: Dict[str, torch.Tensor],
    num_q_heads: int,
    num_kv_heads: int,
    tp_rank: int,
    tp_size: int,
    gpu_id: int,
) -> None:
    """
    Daemon process that receives weight updates from trainers via NCCL.

    This runs inside a subprocess spawned by PatchedGPUModelRunner. It:
    1. Joins NCCL/Gloo process groups with the trainer
    2. Receives weight update broadcasts from rank 0 (trainer)
    3. Copies updated weights directly into the shared state_dict

    Since state_dict tensors have share_memory_() called on them, the main
    vLLM process immediately sees the updates for inference.

    Environment variables read here:
        NUM_INFERENCE_NODES: 0 (default) selects single-node mode; any other
            value is the number of inference nodes in multi-node mode.
        WEIGHT_UPDATER_DEBUG: a non-zero integer enables per-parameter logging.
        CUDA_VISIBLE_DEVICES: used in multi-node mode to derive this worker's rank.

    Args:
        state_dict: Model state dict with shared memory tensors
        num_q_heads: Number of query attention heads (for permutation)
        num_kv_heads: Number of key/value attention heads
        tp_rank: Tensor parallel rank of this worker
        tp_size: Total tensor parallel size
        gpu_id: GPU device ID for this worker

    Returns early (after logging) if ``state_dict`` is empty, the master
    address or this worker's rank cannot be determined, or process-group
    creation fails. Otherwise it loops forever.
    """
    if not state_dict:
        # The receive device is taken from the first tensor; without one there
        # is nothing to update and nowhere to receive into.
        print("[Updater] state_dict is empty, nothing to update; exiting", flush=True)
        return

    # Configuration from environment
    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
    # Only truthiness is used downstream; bool matches _apply_weight_update's annotation.
    debug = bool(int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0)))

    # Get network info (the gloo/inference addresses are not used here).
    master_addr, _master_gloo_addr, _master_inference_addr, urls = get_inference_urls(
        num_inference_nodes
    )
    if master_addr is None:
        print("[Updater] Master address not found, exiting", flush=True)
        return

    # NOTE(review): the device is selected by tp_rank, not gpu_id; presumably the
    # spawning process restricts visible devices so these agree — confirm.
    torch.cuda.set_device(tp_rank)
    print(
        f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
        f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
        flush=True,
    )

    if num_inference_nodes == 0:
        # Single node: trainer is rank 0, this daemon is rank 1 + tp_rank.
        # NOTE(review): world_size is fixed at 2 (one trainer + one inference
        # rank), so with tp_size > 1 every tp_rank >= 1 gets a rank outside the
        # world. Supporting TP > 1 needs a matching trainer-side change — confirm.
        num_training_ranks = 1
        num_inference_ranks = 1
        world_size = num_training_ranks + num_inference_ranks
        rank = num_training_ranks + tp_rank
    else:
        # Multi-node: 8 GPUs per node; rank = node index * 8 + local CUDA device.
        hostnames = get_hostnames()
        cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
        ranks_per_node = 8
        world_size = num_inference_nodes * ranks_per_node
        rank = -1
        for i, url in enumerate(urls or []):
            if hostnames and url in hostnames:
                rank = ranks_per_node * i + int(cuda_devices[gpu_id])
                break
        if rank < 0:
            print("[Updater] Could not determine rank for multi-node, exiting", flush=True)
            return

    print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)

    # Load the bridge config; its sorted param_mappings keys define the index
    # space used by the trainer's broadcasts.
    print("[Updater] Loading bridge config...", flush=True)
    json_data: Dict[str, Any] = {}
    param_name_list: List[str] = []
    try:
        json_data = get_json_data()
        param_name_list = sorted(json_data.get("param_mappings", {}).keys())
        print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True)
    except Exception as e:
        # With no mappings, no index the trainer sends can be resolved, so every
        # update will be skipped.
        print(
            f"[Updater] No config file found (parameter updates will be skipped): {e}",
            flush=True,
        )

    total_group_size = world_size
    if num_inference_nodes == 0:
        # Single-node mode: the trainer is the only training rank (rank 0).
        num_training_gpus = 1
    else:
        num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
    print(f"[Updater] Total group size: {total_group_size}", flush=True)
    print(f"[Updater] Training ranks: {num_training_gpus}", flush=True)
    print(f"[Updater] My rank: {rank}", flush=True)

    print("[Updater] Creating process groups...", flush=True)
    try:
        # Gloo group for coordination
        gloo_group = init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="gloo_group",
        )
        print("[Updater] ✓ Gloo group created", flush=True)
        # NCCL group for tensor transfers
        nccl_group = init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="weight_update_group",
        )
        print("[Updater] ✓ NCCL group created", flush=True)
    except Exception as e:
        print(f"[Updater] Failed to create process groups: {e}", flush=True)
        import traceback

        traceback.print_exc()
        return

    my_device = next(iter(state_dict.values())).device

    # The first inference rank (immediately after the training ranks) writes
    # the dtype mapping file.
    if rank == num_training_gpus:
        _write_dtype_mapping(state_dict, json_data)

    print("[Updater] Entering update loop...", flush=True)
    # Staging buffers for projections that vLLM stores fused (QKV, gate/up, w13).
    qkv_buffer: Dict[str, Any] = {}
    gate_up_buffer: Dict[str, Any] = {}
    qkv_bias_buffer: Dict[str, Any] = {}
    w1w3_buffer: Dict[str, Any] = {}
    # Reused receive buffer for the parameter index; broadcast overwrites it in place.
    obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)

    with torch.no_grad():
        while True:
            try:
                # Receive parameter index from trainer (rank 0)
                dist.broadcast(obj_indx, src=0, group=nccl_group)
                tt_indx = obj_indx.item()

                # -1 signals no update this round (heartbeat)
                if tt_indx == -1:
                    continue

                # NOTE(review): the skips below go straight back to the broadcast
                # without joining the all_gather the trainer presumably issues for
                # this index, which would desynchronise the NCCL collective
                # sequence. Confirm the trainer never sends such indices.
                # Other negative values are rejected explicitly: Python's negative
                # indexing would otherwise silently pick a parameter from the tail.
                if tt_indx < 0 or tt_indx >= len(param_name_list):
                    print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
                    continue

                tt_name = param_name_list[tt_indx]
                param_info = json_data["param_mappings"].get(tt_name, {})
                vllm_name = param_info.get("vllm_name", tt_name)
                local_shape = param_info.get("local_shape", [])

                if vllm_name not in state_dict:
                    if debug:
                        print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
                    continue

                target_dtype = state_dict[vllm_name].dtype
                if debug:
                    print(
                        f"[Updater] Receiving {tt_name} -> {vllm_name}, "
                        f"shape={local_shape}, dtype={target_dtype}",
                        flush=True,
                    )

                # Receive buffers: `local_shape` for training ranks, one element
                # for inference ranks (this process contributes a 1-element zero).
                tensor_list = [
                    torch.zeros(
                        local_shape if idx < num_training_gpus else [1],
                        dtype=target_dtype,
                        device=my_device,
                    )
                    for idx in range(total_group_size)
                ]
                dist.all_gather(
                    tensor_list,
                    torch.zeros(1, dtype=target_dtype, device=my_device),
                    group=nccl_group,
                )

                # Only the training ranks' contributions carry weights.
                tensor = _merge_tensors(
                    tensor_list[:num_training_gpus],
                    json_data,
                    param_info,
                    state_dict[vllm_name],
                )

                # Apply updates (handling merged QKV, gate_up, etc.)
                _apply_weight_update(
                    state_dict,
                    vllm_name,
                    tt_name,
                    tensor,
                    param_info,
                    num_q_heads,
                    num_kv_heads,
                    qkv_buffer,
                    gate_up_buffer,
                    qkv_bias_buffer,
                    w1w3_buffer,
                    debug,
                )
            except Exception as e:
                # Best-effort daemon: log, back off briefly, keep serving.
                print(f"[Updater] Error in update loop: {e}", flush=True)
                import traceback

                traceback.print_exc()
                time.sleep(1)
def _write_dtype_mapping(
    state_dict: Dict[str, torch.Tensor],
    json_data: Dict[str, Any],
) -> None:
    """Write ``vllm_dtypes.json`` (name -> dtype) for trainer reference.

    Each state-dict key is expanded into the name(s) that
    ``get_name_conversions`` maps it to; keys without a conversion are used
    as-is. Values are the dtype string without the ``torch.`` prefix (e.g.
    ``"bfloat16"``). The file is written to ``$LOGDIR``, or to the current
    directory when ``LOGDIR`` is unset or empty.

    Best-effort: any failure is logged and swallowed so the updater keeps running.
    """
    try:
        # `or` also covers LOGDIR="", which previously produced "/vllm_dtypes.json".
        log_dir = os.environ.get("LOGDIR") or "."
        name_conversions = get_name_conversions(json_data.get("param_mappings", {}))
        weight_dtypes = {}
        for name, param in state_dict.items():
            dtype_name = str(param.dtype).split(".")[-1]
            for tt_name in name_conversions.get(name, [name]):
                weight_dtypes[tt_name] = dtype_name
        out_path = os.path.join(log_dir, "vllm_dtypes.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(weight_dtypes, f, indent=2)
        print("[Updater] Wrote dtype mapping", flush=True)
    except Exception as e:
        print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)
def _merge_tensors(
tensor_list: List[torch.Tensor],
json_data: Dict[str, Any],
param_info: Dict[str, Any],
target_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Merge tensors from distributed training into single tensor.
Handles FSDP (data parallel) and TP (tensor parallel) sharding.
"""
dp_shard_degree = json_data.get("dp_shard_degree", 1)
tp_degree = json_data.get("tp_degree", 1)
tp_shard_dim = param_info.get("tp_shard_dim", 0)
if dp_shard_degree > 1:
# First merge across data parallel dimension
tp_tensors = []
for i in range(tp_degree):
dp_tensors = tensor_list[i::tp_degree]
tp_tensors.append(torch.cat(dp_tensors, dim=0))
# Then merge across tensor parallel dimension if needed
if tp_degree > 1:
if tp_tensors[0].shape == target_tensor.shape:
tensor = tp_tensors[0].contiguous()
else:
tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
else:
tensor = tp_tensors[0].contiguous()
else:
# No FSDP, just merge TP shards
tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()
# Cast to target dtype if needed
if tensor.dtype != target_tensor.dtype:
tensor = tensor.to(target_tensor.dtype)
return tensor
def _apply_weight_update(
state_dict: Dict[str, torch.Tensor],
vllm_name: str,
tt_name: str,
tensor: torch.Tensor,
param_info: Dict[str, Any],
num_q_heads: int,
num_kv_heads: int,
qkv_buffer: Dict[str, torch.Tensor],
gate_up_buffer: Dict[str, torch.Tensor],
qkv_bias_buffer: Dict[str, torch.Tensor],
w1w3_buffer: Dict[str, torch.Tensor],
debug: bool,
) -> None:
"""
Apply weight update to state_dict, handling merged projections.
vLLM often merges QKV projections and gate/up projections into single
tensors for efficiency. This handles unpacking and merging correctly.
"""
needs_permute = param_info.get("needs_permute", False)
shape = param_info.get("shape", list(tensor.shape))
def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
if debug:
diff = (new.float() - old.float()).abs()
print(
f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
f"std={diff.std().item():.6e}",
flush=True,
)
# Handle merged QKV projection weights
if "qkv_proj.weight" in vllm_name:
key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
"v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
if key_val == "q" and needs_permute:
tensor = permute(tensor, num_q_heads)
elif key_val == "k" and needs_permute:
tensor = permute(tensor, num_kv_heads)
qkv_buffer[key_val] = tensor
if len(qkv_buffer) == 3:
merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
qkv_buffer.clear()
# Handle merged gate/up projection weights
elif "gate_up_proj.weight" in vllm_name:
key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
gate_up_buffer[key_val] = tensor
if len(gate_up_buffer) == 2:
merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
gate_up_buffer.clear()
# Handle merged w1/w3 weights (alternative naming)
elif "w13_weight" in vllm_name:
key_val = "w1" if ".w1" in tt_name else "w3"
w1w3_buffer[key_val] = tensor
if len(w1w3_buffer) == 2:
merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
w1w3_buffer.clear()
# Handle merged QKV bias
elif "qkv_proj.bias" in vllm_name:
key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k"
if key_val == "q" and needs_permute:
tensor = permute_1d(tensor, num_q_heads)
elif key_val == "k" and needs_permute:
tensor = permute_1d(tensor, num_kv_heads)
qkv_bias_buffer[key_val] = tensor
if len(qkv_bias_buffer) == 3:
merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0)
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
state_dict[vllm_name].data.copy_(merged.contiguous())
qkv_bias_buffer.clear()
# Handle regular weights (possibly needing permutation)
elif needs_permute:
if len(shape) == 2:
tensor = permute(tensor, shape[0]).contiguous()
elif len(shape) == 1:
tensor = permute_1d(tensor, shape[0]).contiguous()
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
state_dict[vllm_name].data.copy_(tensor)
# Simple weight copy
else:
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
state_dict[vllm_name].data.copy_(tensor)