param locations update

This commit is contained in:
Jai Suphavadeeprasit 2025-12-29 20:11:23 -05:00
parent e2c99f7f97
commit ff8eaf9e3c
4 changed files with 117 additions and 284 deletions

View file

@@ -111,19 +111,47 @@ def _create_patched_runner(BaseRunner: type) -> type:
traceback.print_exc()
def _setup_shared_memory(self) -> None:
"""Move model tensors to shared memory."""
"""Move model tensors to shared memory and export param info."""
import json
from pathlib import Path
# Make entire model shareable
self.model.share_memory()
# Also share_memory_() on each parameter individually
# (some implementations may need this)
state_dict = self.model.state_dict()
for key, val in state_dict.items():
if val.is_cuda or val.device.type == 'cuda':
# For CUDA tensors, we need to ensure they're in shared memory
val.share_memory_()
print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")
# Export parameter info to JSON for trainer
log_dir = os.environ.get("LOGDIR", ".")
json_path = Path(log_dir) / "vllm_bridge_config.json"
param_mappings = {}
for name, tensor in state_dict.items():
param_mappings[name] = {
"vllm_name": name,
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
}
info = {
"model": str(self.model_config.model),
"tp_degree": self.parallel_config.tensor_parallel_size,
"dp_shard_degree": 1,
"param_mappings": param_mappings,
"param_names": sorted(state_dict.keys()),
}
try:
with open(json_path, "w") as f:
json.dump(info, f, indent=2)
print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
except Exception as e:
print(f"[vLLM Patch] Warning: Failed to export params: {e}")
def _spawn_weight_updater(self) -> None:
"""Spawn the daemon process for receiving weight updates."""

View file

@@ -8,10 +8,9 @@ copying them directly into vLLM's shared memory tensors.
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, List, Optional
from typing import Dict
import torch
import torch.distributed as dist
@@ -20,10 +19,6 @@ from .distributed_utils import (
init_process_group,
get_inference_urls,
get_hostnames,
get_json_data,
get_name_conversions,
permute,
permute_1d,
)
@@ -110,28 +105,18 @@ def weight_updater_process(
print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)
# Load config from vLLM (optional - may not exist for simple setups)
print("[Updater] Loading bridge config...", flush=True)
json_data = {}
param_name_list = []
try:
json_data = get_json_data()
param_name_list = sorted(json_data.get("param_mappings", {}).keys())
print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True)
except Exception as e:
print(f"[Updater] No config file found (will receive all params): {e}", flush=True)
# Use state_dict keys as parameter list (we already have the model!)
param_name_list = sorted(state_dict.keys())
print(f"[Updater] Model has {len(param_name_list)} parameters", flush=True)
# Use the world_size and rank we already calculated
total_group_size = world_size
# For single-node mode, trainer is rank 0
if num_inference_nodes == 0:
num_training_gpus = 1
else:
num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
num_training_ranks = 1 if num_inference_nodes == 0 else 1
print(f"[Updater] Total group size: {total_group_size}", flush=True)
print(f"[Updater] Training ranks: {num_training_gpus}", flush=True)
print(f"[Updater] Training ranks: {num_training_ranks}", flush=True)
print(f"[Updater] My rank: {rank}", flush=True)
# Initialize process groups
@@ -167,17 +152,18 @@ def weight_updater_process(
# Get device for tensors
my_device = next(iter(state_dict.values())).device
# Write dtype mapping if rank 0
if rank == num_training_gpus: # First inference rank
_write_dtype_mapping(state_dict, json_data)
# Build param info dict from state_dict
param_info_dict = {}
for name, tensor in state_dict.items():
param_info_dict[name] = {
"shape": list(tensor.shape),
"dtype": tensor.dtype,
}
print("[Updater] Entering update loop...", flush=True)
print(f"[Updater] Waiting for weight updates from trainer (rank 0)...", flush=True)
# Buffers for merged QKV and gate_up projections
qkv_buffer = {}
gate_up_buffer = {}
qkv_bias_buffer = {}
w1w3_buffer = {}
update_count = 0
with torch.no_grad():
while True:
@@ -188,76 +174,43 @@ def weight_updater_process(
tt_indx = obj_indx.item()
# -1 signals no update this round (heartbeat)
# -1 signals heartbeat (no update)
if tt_indx == -1:
continue
# -2 signals shutdown
if tt_indx == -2:
print("[Updater] Received shutdown signal", flush=True)
break
# Get parameter info
if tt_indx >= len(param_name_list):
print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
continue
tt_name = param_name_list[tt_indx]
param_info = json_data["param_mappings"].get(tt_name, {})
vllm_name = param_info.get("vllm_name", tt_name)
local_shape = param_info.get("local_shape", [])
if vllm_name not in state_dict:
if tt_indx < 0 or tt_indx >= len(param_name_list):
if debug:
print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
continue
target_dtype = state_dict[vllm_name].dtype
param_name = param_name_list[tt_indx]
if debug:
print(
f"[Updater] Receiving {tt_name} -> {vllm_name}, "
f"shape={local_shape}, dtype={target_dtype}",
flush=True,
)
if param_name not in state_dict:
if debug:
print(f"[Updater] {param_name} not in state_dict, skipping", flush=True)
continue
# Gather tensors from all training ranks
tensor_list = [
torch.zeros(
local_shape if idx < num_training_gpus else [1],
dtype=target_dtype,
device=my_device,
)
for idx in range(total_group_size)
]
target_tensor = state_dict[param_name]
target_shape = list(target_tensor.shape)
target_dtype = target_tensor.dtype
dist.all_gather(
tensor_list,
torch.zeros(1, dtype=target_dtype, device=my_device),
group=nccl_group,
)
# Receive the tensor from trainer
# Trainer sends via broadcast, we receive
received_tensor = torch.zeros(target_shape, dtype=target_dtype, device=my_device)
dist.broadcast(received_tensor, src=0, group=nccl_group)
# Only keep training tensors
tensor_list = tensor_list[:num_training_gpus]
# Copy to shared memory
state_dict[param_name].data.copy_(received_tensor)
# Merge tensors from different parallel configurations
tensor = _merge_tensors(
tensor_list,
json_data,
param_info,
state_dict[vllm_name],
)
# Apply updates (handling merged QKV, gate_up, etc.)
_apply_weight_update(
state_dict,
vllm_name,
tt_name,
tensor,
param_info,
num_q_heads,
num_kv_heads,
qkv_buffer,
gate_up_buffer,
qkv_bias_buffer,
w1w3_buffer,
debug,
)
update_count += 1
if debug or (update_count % 50 == 0):
print(f"[Updater] Updated {param_name} (#{update_count})", flush=True)
except Exception as e:
print(f"[Updater] Error in update loop: {e}", flush=True)
@@ -266,172 +219,5 @@ def weight_updater_process(
time.sleep(1)
def _write_dtype_mapping(
    state_dict: Dict[str, torch.Tensor],
    json_data: Dict[str, Any],
) -> None:
    """Persist a trainer-name -> dtype-string map to LOGDIR/vllm_dtypes.json.

    Each vLLM parameter name is expanded to its trainer-side aliases via
    ``get_name_conversions`` before recording its dtype. Failures are
    logged and swallowed: the mapping file is best-effort, not critical.
    """
    try:
        out_dir = os.environ.get("LOGDIR", ".")
        conversions = get_name_conversions(json_data.get("param_mappings", {}))
        # One entry per trainer-side alias; dtype rendered without the
        # "torch." prefix (e.g. "float16" rather than "torch.float16").
        dtype_map = {
            trainer_name: str(state_dict[vllm_name].dtype).split(".")[-1]
            for vllm_name in state_dict
            for trainer_name in conversions.get(vllm_name, [vllm_name])
        }
        with open(f"{out_dir}/vllm_dtypes.json", "w") as f:
            json.dump(dtype_map, f, indent=2)
        print("[Updater] Wrote dtype mapping", flush=True)
    except Exception as e:
        print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)
def _merge_tensors(
tensor_list: List[torch.Tensor],
json_data: Dict[str, Any],
param_info: Dict[str, Any],
target_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Merge tensors from distributed training into single tensor.
Handles FSDP (data parallel) and TP (tensor parallel) sharding.
"""
dp_shard_degree = json_data.get("dp_shard_degree", 1)
tp_degree = json_data.get("tp_degree", 1)
tp_shard_dim = param_info.get("tp_shard_dim", 0)
if dp_shard_degree > 1:
# First merge across data parallel dimension
tp_tensors = []
for i in range(tp_degree):
dp_tensors = tensor_list[i::tp_degree]
tp_tensors.append(torch.cat(dp_tensors, dim=0))
# Then merge across tensor parallel dimension if needed
if tp_degree > 1:
if tp_tensors[0].shape == target_tensor.shape:
tensor = tp_tensors[0].contiguous()
else:
tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
else:
tensor = tp_tensors[0].contiguous()
else:
# No FSDP, just merge TP shards
tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()
# Cast to target dtype if needed
if tensor.dtype != target_tensor.dtype:
tensor = tensor.to(target_tensor.dtype)
return tensor
def _apply_weight_update(
    state_dict: Dict[str, torch.Tensor],
    vllm_name: str,
    tt_name: str,
    tensor: torch.Tensor,
    param_info: Dict[str, Any],
    num_q_heads: int,
    num_kv_heads: int,
    qkv_buffer: Dict[str, torch.Tensor],
    gate_up_buffer: Dict[str, torch.Tensor],
    qkv_bias_buffer: Dict[str, torch.Tensor],
    w1w3_buffer: Dict[str, torch.Tensor],
    debug: bool,
) -> None:
    """
    Apply weight update to state_dict, handling merged projections.

    vLLM often merges QKV projections and gate/up projections into single
    tensors for efficiency. This handles unpacking and merging correctly.

    Args:
        state_dict: vLLM model tensors, updated in place via ``.data.copy_()``.
        vllm_name: Destination parameter name; its substrings select which
            merged-projection branch (qkv / gate_up / w13 / bias) applies.
        tt_name: Trainer-side parameter name; its substrings decide which
            slot (q/k/v or w1/w3) of a merged projection this tensor fills.
        tensor: The fully merged tensor received for this parameter.
        param_info: Per-parameter metadata; ``needs_permute`` and ``shape``
            are read here.
        num_q_heads: Head count passed to ``permute``/``permute_1d`` for q.
        num_kv_heads: Head count passed to ``permute``/``permute_1d`` for k.
        qkv_buffer: Caller-owned staging dict for q/k/v weight pieces;
            flushed into state_dict and cleared once all three arrive.
        gate_up_buffer: Staging dict for w1/w3 of gate_up_proj (2 pieces).
        qkv_bias_buffer: Staging dict for q/k/v bias pieces (3 pieces).
        w1w3_buffer: Staging dict for w13_weight pieces (2 pieces).
        debug: When True, print the mean/std of the update delta.
    """
    needs_permute = param_info.get("needs_permute", False)
    shape = param_info.get("shape", list(tensor.shape))
    def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
        # Report how far the incoming weights moved from the current ones.
        if debug:
            diff = (new.float() - old.float()).abs()
            print(
                f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
                f"std={diff.std().item():.6e}",
                flush=True,
            )
    # Handle merged QKV projection weights
    if "qkv_proj.weight" in vllm_name:
        # Classify the incoming piece by trainer naming; anything that is
        # neither q nor v falls through to k.
        key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
            "v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
        # NOTE(review): permute presumably reorders head layout for q/k
        # weights (v is never permuted) — confirm against distributed_utils.
        if key_val == "q" and needs_permute:
            tensor = permute(tensor, num_q_heads)
        elif key_val == "k" and needs_permute:
            tensor = permute(tensor, num_kv_heads)
        qkv_buffer[key_val] = tensor
        # Flush only once all three of q/k/v have been staged.
        if len(qkv_buffer) == 3:
            merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            qkv_buffer.clear()
    # Handle merged gate/up projection weights
    elif "gate_up_proj.weight" in vllm_name:
        key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
        gate_up_buffer[key_val] = tensor
        # Flush once both gate (w1) and up (w3) halves are present.
        if len(gate_up_buffer) == 2:
            merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            gate_up_buffer.clear()
    # Handle merged w1/w3 weights (alternative naming)
    elif "w13_weight" in vllm_name:
        key_val = "w1" if ".w1" in tt_name else "w3"
        w1w3_buffer[key_val] = tensor
        if len(w1w3_buffer) == 2:
            # NOTE(review): concatenated on dim=1 here, unlike dim=0 above —
            # presumably a stacked-expert layout; confirm against the model.
            merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            w1w3_buffer.clear()
    # Handle merged QKV bias
    elif "qkv_proj.bias" in vllm_name:
        key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k"
        if key_val == "q" and needs_permute:
            tensor = permute_1d(tensor, num_q_heads)
        elif key_val == "k" and needs_permute:
            tensor = permute_1d(tensor, num_kv_heads)
        qkv_bias_buffer[key_val] = tensor
        if len(qkv_bias_buffer) == 3:
            merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0)
            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
            state_dict[vllm_name].data.copy_(merged.contiguous())
            qkv_bias_buffer.clear()
    # Handle regular weights (possibly needing permutation)
    elif needs_permute:
        # 2-D tensors use the matrix permute, 1-D the vector variant; the
        # leading dim of the recorded shape is passed as the head count.
        if len(shape) == 2:
            tensor = permute(tensor, shape[0]).contiguous()
        elif len(shape) == 1:
            tensor = permute_1d(tensor, shape[0]).contiguous()
        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
        state_dict[vllm_name].data.copy_(tensor)
    # Simple weight copy
    else:
        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
        state_dict[vllm_name].data.copy_(tensor)
# Note: Advanced multi-GPU tensor parallelism support removed for simplicity.
# For single-node mode, we use direct tensor broadcast which is sufficient.