fused memory

This commit is contained in:
Jai Suphavadeeprasit 2026-01-13 22:24:07 -05:00
parent 9a95ec5aa1
commit 906802299c

View file

@ -518,87 +518,119 @@ def _attach_to_vllm_shared_tensors(
hf_state_dict = {} hf_state_dict = {}
vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles) vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
# Helper to create tensor from IPC handle
def _create_ipc_tensor(ipc_info):
    """Rebuild a CUDA tensor from a serialized IPC-handle dict.

    ``ipc_info`` carries the eight fields produced by storage
    ``_share_cuda_()`` (the binary ones base64-encoded) plus the
    tensor-level dtype/shape/stride/offset needed to view the shared
    storage. The returned tensor aliases the exporting process's memory.
    """
    dev = ipc_info["device_index"]
    # _new_shared_cuda expects exactly the 8-tuple _share_cuda_() returned,
    # in this order.
    storage = torch.UntypedStorage._new_shared_cuda(
        dev,
        base64.b64decode(ipc_info["ipc_handle_b64"]),
        ipc_info["storage_size"],
        ipc_info["storage_offset_orig"],
        base64.b64decode(ipc_info["ref_counter_handle_b64"]),
        ipc_info["ref_counter_offset"],
        base64.b64decode(ipc_info["event_handle_b64"]),
        ipc_info["event_sync_required"],
    )
    # dtype arrives as a string like "torch.bfloat16"; resolve it on torch.
    torch_dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
    view = torch.tensor([], dtype=torch_dtype, device=f"cuda:{dev}")
    view.set_(
        storage,
        storage_offset=ipc_info["tensor_storage_offset"],
        size=ipc_info["shape"],
        stride=ipc_info["stride"],
    )
    return view
# Cache of fused tensors already attached via IPC, keyed by vLLM tensor
# name, so each fused storage is reconstructed from its handle at most once.
fused_tensor_cache = {}
# '_fused_' is a sentinel key in the mapping holding
# hf_name -> (vllm_name, fuse_type, slice_idx) entries that need slicing
# out of a fused vLLM tensor rather than direct attachment.
fused_mappings = vllm_to_hf_mapping.pop('_fused_', {})
attached_count = 0 attached_count = 0
# First, handle direct mappings
for hf_name, vllm_name in vllm_to_hf_mapping.items(): for hf_name, vllm_name in vllm_to_hf_mapping.items():
if vllm_name not in ipc_handles: if vllm_name not in ipc_handles:
continue continue
ipc_info = ipc_handles[vllm_name]
try: try:
# Reconstruct tensor from IPC handle ipc_info = ipc_handles[vllm_name]
# We need all 8 items from the original _share_cuda_() call
if "ipc_handle_b64" not in ipc_info: if "ipc_handle_b64" not in ipc_info:
print(f"[Setup] Missing ipc_handle_b64 for {hf_name}")
continue continue
# DEBUG: Only try first tensor to see if IPC works at all tensor = _create_ipc_tensor(ipc_info)
if attached_count == 0:
print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True)
print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True)
print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True)
print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True)
# Decode all the bytes fields from base64
device_index = ipc_info["device_index"]
ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
storage_size = ipc_info["storage_size"]
storage_offset_orig = ipc_info["storage_offset_orig"]
ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
ref_counter_offset = ipc_info["ref_counter_offset"]
event_handle = base64.b64decode(ipc_info["event_handle_b64"])
event_sync_required = ipc_info["event_sync_required"]
if attached_count == 0:
print(f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", flush=True)
print(f"[Setup DEBUG] About to call _new_shared_cuda...", flush=True)
# Reconstruct the 8-tuple that _new_shared_cuda expects
share_tuple = (
device_index,
ipc_handle,
storage_size,
storage_offset_orig,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
)
# Create storage from IPC handle (needs all 8 items)
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
if attached_count == 0:
print(f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True)
# Reconstruct tensor
dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
tensor.set_(
storage,
storage_offset=ipc_info["tensor_storage_offset"],
size=ipc_info["shape"],
stride=ipc_info["stride"],
)
if attached_count == 0:
print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True)
# Make tensor require gradients for training
tensor.requires_grad_(True) tensor.requires_grad_(True)
hf_state_dict[hf_name] = tensor hf_state_dict[hf_name] = tensor
attached_count += 1 attached_count += 1
if attached_count == 1: if attached_count == 1:
print(f"[Setup DEBUG] ✓ First tensor attached successfully!", flush=True) print(f"[Setup DEBUG] ✓ First tensor attached: {hf_name}", flush=True)
except Exception as e: except Exception as e:
print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True) print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
import traceback
traceback.print_exc()
continue continue
print(f"[Setup] Direct attachments: {attached_count}")
# Now handle fused mappings (QKV, gate_up): slice views out of the fused
# vLLM tensors so the HF params share memory with vLLM's storage.
fused_count = 0
# Hoist: building this dict walks the whole module tree; the original
# rebuilt it up to three times per fused parameter inside the loop.
named_params = dict(model.named_parameters())
for hf_name, (vllm_name, fuse_type, slice_idx) in fused_mappings.items():
    if vllm_name not in ipc_handles:
        continue
    try:
        # Get (or lazily create) the fused tensor; the cache avoids
        # re-attaching the same IPC storage once per slice.
        if vllm_name not in fused_tensor_cache:
            ipc_info = ipc_handles[vllm_name]
            if "ipc_handle_b64" not in ipc_info:
                continue
            fused_tensor_cache[vllm_name] = _create_ipc_tensor(ipc_info)
        fused_tensor = fused_tensor_cache[vllm_name]
        # vLLM fuses along dim 0 (output features), so slicing dim 0
        # yields views that share memory with the fused tensor.
        if fuse_type == 'qkv':
            # Layout is [Q | K | V]. With GQA, Q is larger than K/V, so
            # the split sizes come from the HF model's expected shapes.
            expected_size = named_params[hf_name].shape[0]
            if slice_idx == 0:  # Q starts at offset 0
                start = 0
            elif slice_idx == 1:  # K starts after Q
                start = named_params[hf_name.replace('.k_proj.', '.q_proj.')].shape[0]
            else:  # V starts after Q and K
                start = (
                    named_params[hf_name.replace('.v_proj.', '.q_proj.')].shape[0]
                    + named_params[hf_name.replace('.v_proj.', '.k_proj.')].shape[0]
                )
            sliced = fused_tensor[start:start + expected_size]
        elif fuse_type == 'gate_up':
            # Layout is [gate | up] in equal halves.
            half_size = fused_tensor.shape[0] // 2
            sliced = fused_tensor[:half_size] if slice_idx == 0 else fused_tensor[half_size:]
        else:
            # Previously an unknown fuse_type left `sliced` unbound and the
            # broad except reported a confusing NameError; fail explicitly.
            print(f"[Setup] Unknown fuse type {fuse_type!r} for {hf_name}", flush=True)
            continue
        # The slice shares memory with the fused tensor.
        sliced.requires_grad_(True)
        hf_state_dict[hf_name] = sliced
        fused_count += 1
    except Exception as e:
        print(f"[Setup] Failed to slice {hf_name} from {vllm_name}: {e}", flush=True)
        continue
print(f"[Setup] Fused slices: {fused_count}")
attached_count += fused_count
if attached_count == 0: if attached_count == 0:
print("[Setup] Could not attach any tensors - IPC failed") print("[Setup] Could not attach any tensors - IPC failed")
print("[Setup] Model loaded normally (not sharing memory with vLLM)") print("[Setup] Model loaded normally (not sharing memory with vLLM)")
@ -656,21 +688,18 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
""" """
Create mapping from HuggingFace parameter names to vLLM tensor names. Create mapping from HuggingFace parameter names to vLLM tensor names.
vLLM uses slightly different naming conventions than HuggingFace. vLLM uses fused layers that need special handling:
This function creates the bidirectional mapping. - qkv_proj -> q_proj, k_proj, v_proj (need slicing)
- gate_up_proj -> gate_proj, up_proj (need slicing)
""" """
hf_params = set(model.state_dict().keys()) hf_params = set(model.state_dict().keys())
vllm_params = set(ipc_handles.keys()) vllm_params = set(ipc_handles.keys())
print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors") print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
# Debug: show sample names
hf_sample = sorted(list(hf_params))[:3]
vllm_sample = sorted(list(vllm_params))[:3]
print(f"[Setup] Sample HF names: {hf_sample}")
print(f"[Setup] Sample vLLM names: {vllm_sample}")
mapping = {} mapping = {}
# Track fused tensors that need slicing
fused_mappings = {} # hf_name -> (vllm_name, slice_type, slice_index)
for hf_name in hf_params: for hf_name in hf_params:
# Try direct match first # Try direct match first
@ -678,18 +707,35 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
mapping[hf_name] = hf_name mapping[hf_name] = hf_name
continue continue
# Try common transformations # Check for fused QKV projection
# vLLM often uses 'model.' prefix # HF: model.layers.X.self_attn.q_proj.weight -> vLLM: model.layers.X.self_attn.qkv_proj.weight
vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name if '.self_attn.q_proj.' in hf_name:
if vllm_name in vllm_params: vllm_name = hf_name.replace('.q_proj.', '.qkv_proj.')
mapping[hf_name] = vllm_name
continue
# Remove 'model.' prefix if present
if hf_name.startswith("model."):
vllm_name = hf_name[6:]
if vllm_name in vllm_params: if vllm_name in vllm_params:
mapping[hf_name] = vllm_name fused_mappings[hf_name] = (vllm_name, 'qkv', 0) # Q is first
continue
if '.self_attn.k_proj.' in hf_name:
vllm_name = hf_name.replace('.k_proj.', '.qkv_proj.')
if vllm_name in vllm_params:
fused_mappings[hf_name] = (vllm_name, 'qkv', 1) # K is second
continue
if '.self_attn.v_proj.' in hf_name:
vllm_name = hf_name.replace('.v_proj.', '.qkv_proj.')
if vllm_name in vllm_params:
fused_mappings[hf_name] = (vllm_name, 'qkv', 2) # V is third
continue
# Check for fused gate_up projection
# HF: model.layers.X.mlp.gate_proj.weight -> vLLM: model.layers.X.mlp.gate_up_proj.weight
if '.mlp.gate_proj.' in hf_name:
vllm_name = hf_name.replace('.gate_proj.', '.gate_up_proj.')
if vllm_name in vllm_params:
fused_mappings[hf_name] = (vllm_name, 'gate_up', 0) # gate is first
continue
if '.mlp.up_proj.' in hf_name:
vllm_name = hf_name.replace('.up_proj.', '.gate_up_proj.')
if vllm_name in vllm_params:
fused_mappings[hf_name] = (vllm_name, 'gate_up', 1) # up is second
continue continue
# For lm_head, check if it's tied to embed_tokens # For lm_head, check if it's tied to embed_tokens
@ -697,13 +743,12 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
mapping[hf_name] = "model.embed_tokens.weight" mapping[hf_name] = "model.embed_tokens.weight"
continue continue
print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors") total_mapped = len(mapping) + len(fused_mappings)
print(f"[Setup] Direct mappings: {len(mapping)}, Fused mappings: {len(fused_mappings)}")
print(f"[Setup] Total mapped: {total_mapped} / {len(hf_params)}")
# Show what's NOT mapped # Store fused mappings for later slicing
unmapped = hf_params - set(mapping.keys()) mapping['_fused_'] = fused_mappings
if unmapped:
unmapped_sample = sorted(list(unmapped))[:5]
print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")
return mapping return mapping