"""
|
|
Weight Updater Process - Daemon that receives NCCL weight updates.
|
|
|
|
This process runs as a daemon spawned by the patched vLLM GPUModelRunner.
|
|
It joins NCCL process groups with the trainer and receives weight updates,
|
|
copying them directly into vLLM's shared memory tensors.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from .distributed_utils import (
|
|
init_process_group,
|
|
get_inference_urls,
|
|
get_hostnames,
|
|
get_json_data,
|
|
get_name_conversions,
|
|
permute,
|
|
permute_1d,
|
|
)
|
|
|
|
|
|
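
# Weight-update protocol implemented by weight_updater_process below:
#   1. The trainer (rank 0) broadcasts one int64 value: the index of the
#      parameter being updated, or -1 as a heartbeat (no update this round).
#   2. Every rank then joins an all_gather; training ranks contribute their
#      local shard of that parameter, while inference ranks contribute a
#      dummy scalar.
#   3. This daemon merges the shards and copies the result into the
#      shared-memory tensor that the main vLLM process serves from.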


def weight_updater_process(
    state_dict: Dict[str, torch.Tensor],
    num_q_heads: int,
    num_kv_heads: int,
    tp_rank: int,
    tp_size: int,
    gpu_id: int,
) -> None:
    """
    Daemon process that receives weight updates from trainers via NCCL.

    This runs inside a subprocess spawned by PatchedGPUModelRunner. It:
    1. Joins NCCL/Gloo process groups with the trainer
    2. Receives weight update broadcasts from rank 0 (trainer)
    3. Copies updated weights directly into the shared state_dict

    Since state_dict tensors have share_memory_() called on them, the main
    vLLM process immediately sees the updates for inference.

    Args:
        state_dict: Model state dict with shared memory tensors
        num_q_heads: Number of query attention heads (for permutation)
        num_kv_heads: Number of key/value attention heads
        tp_rank: Tensor parallel rank of this worker
        tp_size: Total tensor parallel size
        gpu_id: GPU device ID for this worker
    """
    # Configuration from environment
    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
    debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))

    # Get network info
    master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
        num_inference_nodes
    )

    if master_addr is None:
        print("[Updater] Master address not found, exiting", flush=True)
        return

    # Set CUDA device
    torch.cuda.set_device(tp_rank)

    print(
        f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
        f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
        flush=True,
    )

    # For single-node mode (num_inference_nodes=0):
    # - Trainer is rank 0
    # - Inference daemon is rank 1 (or tp_rank + 1 for multi-GPU)
    # - Total world size = 1 trainer + 1 inference = 2
    #
    # For multi-node mode:
    # - More complex, based on SLURM node allocation

    if num_inference_nodes == 0:
        # Single node: simple setup
        # World = [trainer (rank 0), inference daemon (rank 1)]
        num_training_ranks = 1
        num_inference_ranks = 1
        world_size = num_training_ranks + num_inference_ranks
        rank = num_training_ranks + tp_rank  # Daemon is rank 1
    else:
        # Multi-node: 8 GPUs per node
        hostnames = get_hostnames()
        cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
        ranks_per_node = 8
        world_size = num_inference_nodes * ranks_per_node

        rank = -1
        for i, url in enumerate(urls or []):
            if hostnames and url in hostnames:
                rank = ranks_per_node * i + int(cuda_devices[gpu_id])
                break

        if rank < 0:
            print("[Updater] Could not determine rank for multi-node, exiting", flush=True)
            return
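
    # In single-node mode the trainer is rank 0 and this daemon sits right
    # after it; in multi-node mode each inference node i hosts
    # `ranks_per_node` consecutive ranks starting at 8 * i, one per local GPU.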

    print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)

    # Load config from vLLM (optional - may not exist for simple setups)
    print("[Updater] Loading bridge config...", flush=True)
    json_data = {}
    param_name_list = []
    try:
        json_data = get_json_data()
        param_name_list = sorted(json_data.get("param_mappings", {}).keys())
        print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True)
    except Exception as e:
        print(f"[Updater] No config file found (will receive all params): {e}", flush=True)

    # Use the world_size and rank we already calculated
    total_group_size = world_size

    # For single-node mode, trainer is rank 0
    if num_inference_nodes == 0:
        num_training_gpus = 1
    else:
        num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)

    print(f"[Updater] Total group size: {total_group_size}", flush=True)
    print(f"[Updater] Training ranks: {num_training_gpus}", flush=True)
    print(f"[Updater] My rank: {rank}", flush=True)

    # Initialize process groups
    print("[Updater] Creating process groups...", flush=True)

    try:
        # Gloo group for coordination
        gloo_group = init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="gloo_group",
        )
        print("[Updater] ✓ Gloo group created", flush=True)

        # NCCL group for tensor transfers
        nccl_group = init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=rank,
            group_name="weight_update_group",
        )
        print("[Updater] ✓ NCCL group created", flush=True)

    except Exception as e:
        print(f"[Updater] Failed to create process groups: {e}", flush=True)
        import traceback
        traceback.print_exc()
        return
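
    # Note: both process groups span the trainer and all inference daemons, so
    # every participant must create them in the same order with a matching
    # world_size, or group initialization will block waiting for stragglers.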

    # Get device for tensors
    my_device = next(iter(state_dict.values())).device

    # Write dtype mapping from the first inference rank only
    if rank == num_training_gpus:  # First inference rank
        _write_dtype_mapping(state_dict, json_data)

    print("[Updater] Entering update loop...", flush=True)

    # Buffers for merged QKV and gate_up projections
    qkv_buffer = {}
    gate_up_buffer = {}
    qkv_bias_buffer = {}
    w1w3_buffer = {}
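
    # q/k/v and gate/up pieces arrive from the trainer as separate parameters
    # but live as fused tensors in vLLM's state_dict, so partial pieces are
    # staged in these buffers until a complete set has been received.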

    with torch.no_grad():
        while True:
            try:
                # Receive parameter index from trainer (rank 0)
                obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)
                dist.broadcast(obj_indx, src=0, group=nccl_group)

                tt_indx = obj_indx.item()

                # -1 signals no update this round (heartbeat)
                if tt_indx == -1:
                    continue

                # Get parameter info
                if tt_indx >= len(param_name_list):
                    print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
                    continue

                tt_name = param_name_list[tt_indx]
                param_info = json_data["param_mappings"].get(tt_name, {})
                vllm_name = param_info.get("vllm_name", tt_name)
                local_shape = param_info.get("local_shape", [])

                if vllm_name not in state_dict:
                    if debug:
                        print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
                    continue

                target_dtype = state_dict[vllm_name].dtype

                if debug:
                    print(
                        f"[Updater] Receiving {tt_name} -> {vllm_name}, "
                        f"shape={local_shape}, dtype={target_dtype}",
                        flush=True,
                    )

                # Gather tensors from all training ranks
                tensor_list = [
                    torch.zeros(
                        local_shape if idx < num_training_gpus else [1],
                        dtype=target_dtype,
                        device=my_device,
                    )
                    for idx in range(total_group_size)
                ]
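
                # Only the first num_training_gpus entries receive real shards
                # from the trainer ranks; this daemon contributes just a dummy
                # scalar below, and the placeholder entries are dropped after
                # the gather.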

                dist.all_gather(
                    tensor_list,
                    torch.zeros(1, dtype=target_dtype, device=my_device),
                    group=nccl_group,
                )

                # Only keep training tensors
                tensor_list = tensor_list[:num_training_gpus]

                # Merge tensors from different parallel configurations
                tensor = _merge_tensors(
                    tensor_list,
                    json_data,
                    param_info,
                    state_dict[vllm_name],
                )

                # Apply updates (handling merged QKV, gate_up, etc.)
                _apply_weight_update(
                    state_dict,
                    vllm_name,
                    tt_name,
                    tensor,
                    param_info,
                    num_q_heads,
                    num_kv_heads,
                    qkv_buffer,
                    gate_up_buffer,
                    qkv_bias_buffer,
                    w1w3_buffer,
                    debug,
                )

            except Exception as e:
                print(f"[Updater] Error in update loop: {e}", flush=True)
                import traceback
                traceback.print_exc()
                time.sleep(1)
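

# For reference, the trainer-side counterpart of one update round is expected
# to look roughly like this (illustrative sketch only; the real trainer lives
# outside this file, and the names `param_index`, `local_shard`, and
# `gather_list` are placeholders):
#
#     idx = torch.tensor([param_index], dtype=torch.long, device=device)
#     dist.broadcast(idx, src=0, group=nccl_group)
#     dist.all_gather(gather_list, local_shard, group=nccl_group)
#
# A heartbeat round broadcasts -1 and skips the all_gather entirely.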


def _write_dtype_mapping(
    state_dict: Dict[str, torch.Tensor],
    json_data: Dict[str, Any],
) -> None:
    """Write dtype mapping file for trainer reference."""
    try:
        log_dir = os.environ.get("LOGDIR", ".")
        name_conversions = get_name_conversions(json_data.get("param_mappings", {}))

        weight_dtypes = {}
        for name in state_dict.keys():
            tt_names = name_conversions.get(name, [name])
            for tt_name in tt_names:
                weight_dtypes[tt_name] = str(state_dict[name].dtype).split(".")[-1]

        with open(f"{log_dir}/vllm_dtypes.json", "w") as f:
            json.dump(weight_dtypes, f, indent=2)

        print("[Updater] Wrote dtype mapping", flush=True)
    except Exception as e:
        print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)


def _merge_tensors(
    tensor_list: List[torch.Tensor],
    json_data: Dict[str, Any],
    param_info: Dict[str, Any],
    target_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Merge tensors from distributed training into a single tensor.

    Handles FSDP (data parallel) and TP (tensor parallel) sharding.
    """
    dp_shard_degree = json_data.get("dp_shard_degree", 1)
    tp_degree = json_data.get("tp_degree", 1)
    tp_shard_dim = param_info.get("tp_shard_dim", 0)

    if dp_shard_degree > 1:
        # First merge across data parallel dimension
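        # The gathered list is assumed to be laid out with the TP index varying
        # fastest (element index = dp * tp_degree + tp), so the strided slice
        # below collects every DP shard belonging to TP rank i.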
        tp_tensors = []
        for i in range(tp_degree):
            dp_tensors = tensor_list[i::tp_degree]
            tp_tensors.append(torch.cat(dp_tensors, dim=0))

        # Then merge across tensor parallel dimension if needed
        if tp_degree > 1:
            if tp_tensors[0].shape == target_tensor.shape:
                tensor = tp_tensors[0].contiguous()
            else:
                tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
        else:
            tensor = tp_tensors[0].contiguous()
    else:
        # No FSDP, just merge TP shards
        tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()

    # Cast to target dtype if needed
    if tensor.dtype != target_tensor.dtype:
        tensor = tensor.to(target_tensor.dtype)

    return tensor


def _apply_weight_update(
    state_dict: Dict[str, torch.Tensor],
    vllm_name: str,
    tt_name: str,
    tensor: torch.Tensor,
    param_info: Dict[str, Any],
    num_q_heads: int,
    num_kv_heads: int,
    qkv_buffer: Dict[str, torch.Tensor],
    gate_up_buffer: Dict[str, torch.Tensor],
    qkv_bias_buffer: Dict[str, torch.Tensor],
    w1w3_buffer: Dict[str, torch.Tensor],
    debug: bool,
) -> None:
    """
    Apply weight update to state_dict, handling merged projections.

    vLLM often merges QKV projections and gate/up projections into single
    tensors for efficiency. This handles unpacking and merging correctly.
    """
    needs_permute = param_info.get("needs_permute", False)
    shape = param_info.get("shape", list(tensor.shape))
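
    # permute / permute_1d reorder attention-head rows when needs_permute is
    # set; presumably they convert between the trainer's and vLLM's
    # rotary-embedding weight layouts (see distributed_utils).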

    def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
        if debug:
            diff = (new.float() - old.float()).abs()
            print(
                f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
                f"std={diff.std().item():.6e}",
                flush=True,
            )

    # Handle merged QKV projection weights
    if "qkv_proj.weight" in vllm_name:
        key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
            "v" if ".wv." in tt_name or "v_proj" in tt_name else "k"

        if key_val == "q" and needs_permute:
            tensor = permute(tensor, num_q_heads)
        elif key_val == "k" and needs_permute:
            tensor = permute(tensor, num_kv_heads)

        qkv_buffer[key_val] = tensor

        if len(qkv_buffer) == 3:
            merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            qkv_buffer.clear()

    # Handle merged gate/up projection weights
    elif "gate_up_proj.weight" in vllm_name:
        key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
        gate_up_buffer[key_val] = tensor

        if len(gate_up_buffer) == 2:
            merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            gate_up_buffer.clear()

    # Handle merged w1/w3 weights (alternative naming)
    elif "w13_weight" in vllm_name:
        key_val = "w1" if ".w1" in tt_name else "w3"
        w1w3_buffer[key_val] = tensor

        if len(w1w3_buffer) == 2:
            merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            w1w3_buffer.clear()

    # Handle merged QKV bias
    elif "qkv_proj.bias" in vllm_name:
        key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k"

        if key_val == "q" and needs_permute:
            tensor = permute_1d(tensor, num_q_heads)
        elif key_val == "k" and needs_permute:
            tensor = permute_1d(tensor, num_kv_heads)

        qkv_bias_buffer[key_val] = tensor

        if len(qkv_bias_buffer) == 3:
            merged = torch.cat(
                [qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0
            )
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            qkv_bias_buffer.clear()

    # Handle regular weights (possibly needing permutation)
    elif needs_permute:
        if len(shape) == 2:
            tensor = permute(tensor, shape[0]).contiguous()
        elif len(shape) == 1:
            tensor = permute_1d(tensor, shape[0]).contiguous()

        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
        state_dict[vllm_name].data.copy_(tensor)

    # Simple weight copy
    else:
        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
        state_dict[vllm_name].data.copy_(tensor)