fused memory

Jai Suphavadeeprasit 2026-01-13 22:24:07 -05:00
parent 9a95ec5aa1
commit 906802299c


@@ -518,87 +518,119 @@ def _attach_to_vllm_shared_tensors(
     hf_state_dict = {}
     vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
+    # Helper to rebuild a tensor from a serialized IPC handle
+    def _create_ipc_tensor(ipc_info):
+        device_index = ipc_info["device_index"]
+        ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
+        storage_size = ipc_info["storage_size"]
+        storage_offset_orig = ipc_info["storage_offset_orig"]
+        ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
+        ref_counter_offset = ipc_info["ref_counter_offset"]
+        event_handle = base64.b64decode(ipc_info["event_handle_b64"])
+        event_sync_required = ipc_info["event_sync_required"]
+        share_tuple = (device_index, ipc_handle, storage_size, storage_offset_orig,
+                       ref_counter_handle, ref_counter_offset, event_handle, event_sync_required)
+        storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
+        dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
+        tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
+        tensor.set_(storage, storage_offset=ipc_info["tensor_storage_offset"],
+                    size=ipc_info["shape"], stride=ipc_info["stride"])
+        return tensor
+    # Cache for fused tensors (avoid rebuilding them from IPC multiple times)
+    fused_tensor_cache = {}
+    # Extract fused mappings
+    fused_mappings = vllm_to_hf_mapping.pop('_fused_', {})
     attached_count = 0
+    # First, handle direct mappings
     for hf_name, vllm_name in vllm_to_hf_mapping.items():
         if vllm_name not in ipc_handles:
             continue
+        ipc_info = ipc_handles[vllm_name]
         try:
-            # Reconstruct tensor from IPC handle
-            # We need all 8 items from the original _share_cuda_() call
-            ipc_info = ipc_handles[vllm_name]
             if "ipc_handle_b64" not in ipc_info:
                 print(f"[Setup] Missing ipc_handle_b64 for {hf_name}")
                 continue
-            # DEBUG: only try the first tensor to see if IPC works at all
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True)
-                print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True)
-                print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True)
-                print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True)
-            # Decode all the bytes fields from base64
-            device_index = ipc_info["device_index"]
-            ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
-            storage_size = ipc_info["storage_size"]
-            storage_offset_orig = ipc_info["storage_offset_orig"]
-            ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
-            ref_counter_offset = ipc_info["ref_counter_offset"]
-            event_handle = base64.b64decode(ipc_info["event_handle_b64"])
-            event_sync_required = ipc_info["event_sync_required"]
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", flush=True)
-                print(f"[Setup DEBUG] About to call _new_shared_cuda...", flush=True)
-            # Reconstruct the 8-tuple that _new_shared_cuda expects
-            share_tuple = (
-                device_index,
-                ipc_handle,
-                storage_size,
-                storage_offset_orig,
-                ref_counter_handle,
-                ref_counter_offset,
-                event_handle,
-                event_sync_required,
-            )
-            # Create storage from the IPC handle (needs all 8 items)
-            storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True)
-            # Reconstruct tensor
-            dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
-            tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
-            tensor.set_(
-                storage,
-                storage_offset=ipc_info["tensor_storage_offset"],
-                size=ipc_info["shape"],
-                stride=ipc_info["stride"],
-            )
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True)
-            # Make tensor require gradients for training
+            # Rebuild the shared tensor and make it require gradients for training
+            tensor = _create_ipc_tensor(ipc_info)
             tensor.requires_grad_(True)
             hf_state_dict[hf_name] = tensor
             attached_count += 1
             if attached_count == 1:
-                print(f"[Setup DEBUG] ✓ First tensor attached successfully!", flush=True)
+                print(f"[Setup DEBUG] ✓ First tensor attached: {hf_name}", flush=True)
         except Exception as e:
             print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
             import traceback
             traceback.print_exc()
             continue
     print(f"[Setup] Direct attachments: {attached_count}")
+    # Now handle fused mappings (QKV, gate_up)
+    fused_count = 0
+    for hf_name, (vllm_name, fuse_type, slice_idx) in fused_mappings.items():
+        if vllm_name not in ipc_handles:
+            continue
+        try:
+            # Get or create the fused tensor
+            if vllm_name not in fused_tensor_cache:
+                ipc_info = ipc_handles[vllm_name]
+                if "ipc_handle_b64" not in ipc_info:
+                    continue
+                fused_tensor_cache[vllm_name] = _create_ipc_tensor(ipc_info)
+            fused_tensor = fused_tensor_cache[vllm_name]
+            # Slice the fused tensor: vLLM fuses along the first dimension (output features)
+            if fuse_type == 'qkv':
+                # QKV is fused: [(q_size + k_size + v_size), hidden]
+                # For Qwen: q_size = num_heads * head_dim, k_size = v_size = num_kv_heads * head_dim
+                # With GQA, Q is larger and K=V are smaller, so take the sizes
+                # from the HF model's expected parameter shapes
+                hf_param = dict(model.named_parameters())[hf_name]
+                expected_size = hf_param.shape[0]
+                if slice_idx == 0:  # Q is first
+                    sliced = fused_tensor[:expected_size]
+                elif slice_idx == 1:  # K starts after Q
+                    q_size = dict(model.named_parameters())[hf_name.replace('.k_proj.', '.q_proj.')].shape[0]
+                    sliced = fused_tensor[q_size:q_size + expected_size]
+                else:  # V starts after Q and K
+                    q_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.q_proj.')].shape[0]
+                    k_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.k_proj.')].shape[0]
+                    sliced = fused_tensor[q_size + k_size:q_size + k_size + expected_size]
+            elif fuse_type == 'gate_up':
+                # gate_up is fused: [(gate_size + up_size), hidden], split in half
+                half_size = fused_tensor.shape[0] // 2
+                if slice_idx == 0:  # gate is first
+                    sliced = fused_tensor[:half_size]
+                else:  # up is second
+                    sliced = fused_tensor[half_size:]
+            # The slice shares memory with the fused tensor!
+            sliced.requires_grad_(True)
+            hf_state_dict[hf_name] = sliced
+            fused_count += 1
+        except Exception as e:
+            print(f"[Setup] Failed to slice {hf_name} from {vllm_name}: {e}", flush=True)
+            continue
+    print(f"[Setup] Fused slices: {fused_count}")
+    attached_count += fused_count
     if attached_count == 0:
         print("[Setup] Could not attach any tensors - IPC failed")
         print("[Setup] Model loaded normally (not sharing memory with vLLM)")
@@ -656,21 +688,18 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
     """
     Create mapping from HuggingFace parameter names to vLLM tensor names.
     vLLM uses slightly different naming conventions than HuggingFace.
-    This function creates the bidirectional mapping.
+    vLLM also uses fused layers that need special handling:
+    - qkv_proj -> q_proj, k_proj, v_proj (need slicing)
+    - gate_up_proj -> gate_proj, up_proj (need slicing)
     """
     hf_params = set(model.state_dict().keys())
     vllm_params = set(ipc_handles.keys())
     print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
-    # Debug: show sample names
-    hf_sample = sorted(list(hf_params))[:3]
-    vllm_sample = sorted(list(vllm_params))[:3]
-    print(f"[Setup] Sample HF names: {hf_sample}")
-    print(f"[Setup] Sample vLLM names: {vllm_sample}")
     mapping = {}
+    # Track fused tensors that need slicing
+    fused_mappings = {}  # hf_name -> (vllm_name, fuse_type, slice_idx)
     for hf_name in hf_params:
         # Try direct match first
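To make the slicing concrete: with grouped-query attention the Q block has more rows than K or V, so splitting the fused tensor into equal thirds would be wrong. A worked example using Qwen2.5-7B-style dimensions (illustrative numbers, not read from this repo):

# Row layout of a fused qkv_proj.weight under GQA.
num_heads, num_kv_heads, head_dim = 28, 4, 128
q_size = num_heads * head_dim      # 3584 rows for Q
k_size = num_kv_heads * head_dim   # 512 rows for K
v_size = num_kv_heads * head_dim   # 512 rows for V
# qkv_proj.weight: [q_size + k_size + v_size, hidden] = [4608, 3584]
q_rows = slice(0, q_size)                                  # rows 0:3584
k_rows = slice(q_size, q_size + k_size)                    # rows 3584:4096
v_rows = slice(q_size + k_size, q_size + k_size + v_size)  # rows 4096:4608
# This is why the attach code reads each expected row count from the HF
# parameter shapes rather than assuming equal splits.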
@@ -678,18 +707,35 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
             mapping[hf_name] = hf_name
             continue
         # Try common transformations:
         # vLLM often uses a 'model.' prefix
         vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
         if vllm_name in vllm_params:
             mapping[hf_name] = vllm_name
             continue
-        # Remove 'model.' prefix if present
-        if hf_name.startswith("model."):
-            vllm_name = hf_name[6:]
+        # Check for fused QKV projections
+        # HF: model.layers.X.self_attn.q_proj.weight -> vLLM: model.layers.X.self_attn.qkv_proj.weight
+        if '.self_attn.q_proj.' in hf_name:
+            vllm_name = hf_name.replace('.q_proj.', '.qkv_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 0)  # Q is first
+                continue
+        if '.self_attn.k_proj.' in hf_name:
+            vllm_name = hf_name.replace('.k_proj.', '.qkv_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 1)  # K is second
+                continue
+        if '.self_attn.v_proj.' in hf_name:
+            vllm_name = hf_name.replace('.v_proj.', '.qkv_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 2)  # V is third
+                continue
+        # Check for fused gate_up projections
+        # HF: model.layers.X.mlp.gate_proj.weight -> vLLM: model.layers.X.mlp.gate_up_proj.weight
+        if '.mlp.gate_proj.' in hf_name:
+            vllm_name = hf_name.replace('.gate_proj.', '.gate_up_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'gate_up', 0)  # gate is first
+                continue
+        if '.mlp.up_proj.' in hf_name:
+            vllm_name = hf_name.replace('.up_proj.', '.gate_up_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'gate_up', 1)  # up is second
+                continue
         # For lm_head, check if it's tied to embed_tokens
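Put together, the structure this function returns looks roughly like the sketch below for one decoder layer (hypothetical Llama/Qwen-style names; '_fused_' is the sentinel key the attach code pops before its direct pass):

mapping = {
    # direct mappings: HF name -> identical vLLM name
    "model.embed_tokens.weight": "model.embed_tokens.weight",
    "model.layers.0.self_attn.o_proj.weight": "model.layers.0.self_attn.o_proj.weight",
    # sentinel entry: HF name -> (fused vLLM name, fuse_type, slice_idx)
    "_fused_": {
        "model.layers.0.self_attn.q_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 0),
        "model.layers.0.self_attn.k_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 1),
        "model.layers.0.self_attn.v_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 2),
        "model.layers.0.mlp.gate_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 0),
        "model.layers.0.mlp.up_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 1),
    },
}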
@@ -697,13 +743,12 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
             mapping[hf_name] = "model.embed_tokens.weight"
             continue
-    print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors")
+    total_mapped = len(mapping) + len(fused_mappings)
+    print(f"[Setup] Direct mappings: {len(mapping)}, Fused mappings: {len(fused_mappings)}")
+    print(f"[Setup] Total mapped: {total_mapped} / {len(hf_params)}")
     # Show what's NOT mapped
-    unmapped = hf_params - set(mapping.keys())
+    unmapped = hf_params - set(mapping.keys()) - set(fused_mappings.keys())
     if unmapped:
         unmapped_sample = sorted(list(unmapped))[:5]
         print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")
+    # Store fused mappings for later slicing
+    mapping['_fused_'] = fused_mappings
     return mapping
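Because fused_tensor[a:b] narrows along dim 0, every slice stored into hf_state_dict is a view, not a copy. A quick, hypothetical sanity check of that aliasing property:

import torch

fused = torch.zeros(4608, 3584, device="cuda:0")  # stand-in for a fused qkv_proj.weight
q = fused[:3584]                                   # the would-be q_proj.weight view
q.add_(1.0)                                        # in-place write through the view
assert q.data_ptr() == fused.data_ptr()            # same storage, same base address
assert bool(fused[:3584].eq(1.0).all())            # the write is visible via the fused tensor
assert bool(fused[3584:].eq(0.0).all())            # K/V rows are untouched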