mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
fused memory
This commit is contained in:
parent
9a95ec5aa1
commit
906802299c
1 changed files with 132 additions and 87 deletions
|
|
@ -518,87 +518,119 @@ def _attach_to_vllm_shared_tensors(
|
||||||
hf_state_dict = {}
|
hf_state_dict = {}
|
||||||
vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
|
vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
|
||||||
|
|
||||||
|
# Helper to create tensor from IPC handle
|
||||||
|
def _create_ipc_tensor(ipc_info):
|
||||||
|
device_index = ipc_info["device_index"]
|
||||||
|
ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
|
||||||
|
storage_size = ipc_info["storage_size"]
|
||||||
|
storage_offset_orig = ipc_info["storage_offset_orig"]
|
||||||
|
ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
|
||||||
|
ref_counter_offset = ipc_info["ref_counter_offset"]
|
||||||
|
event_handle = base64.b64decode(ipc_info["event_handle_b64"])
|
||||||
|
event_sync_required = ipc_info["event_sync_required"]
|
||||||
|
|
||||||
|
share_tuple = (device_index, ipc_handle, storage_size, storage_offset_orig,
|
||||||
|
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required)
|
||||||
|
|
||||||
|
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
|
||||||
|
dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
|
||||||
|
tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
|
||||||
|
tensor.set_(storage, storage_offset=ipc_info["tensor_storage_offset"],
|
||||||
|
size=ipc_info["shape"], stride=ipc_info["stride"])
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# Cache for fused tensors (avoid recreating from IPC multiple times)
|
||||||
|
fused_tensor_cache = {}
|
||||||
|
|
||||||
|
# Extract fused mappings
|
||||||
|
fused_mappings = vllm_to_hf_mapping.pop('_fused_', {})
|
||||||
|
|
||||||
attached_count = 0
|
attached_count = 0
|
||||||
|
|
||||||
|
# First, handle direct mappings
|
||||||
for hf_name, vllm_name in vllm_to_hf_mapping.items():
|
for hf_name, vllm_name in vllm_to_hf_mapping.items():
|
||||||
if vllm_name not in ipc_handles:
|
if vllm_name not in ipc_handles:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ipc_info = ipc_handles[vllm_name]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Reconstruct tensor from IPC handle
|
ipc_info = ipc_handles[vllm_name]
|
||||||
# We need all 8 items from the original _share_cuda_() call
|
|
||||||
if "ipc_handle_b64" not in ipc_info:
|
if "ipc_handle_b64" not in ipc_info:
|
||||||
print(f"[Setup] Missing ipc_handle_b64 for {hf_name}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# DEBUG: Only try first tensor to see if IPC works at all
|
tensor = _create_ipc_tensor(ipc_info)
|
||||||
if attached_count == 0:
|
|
||||||
print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True)
|
|
||||||
print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True)
|
|
||||||
print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True)
|
|
||||||
print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True)
|
|
||||||
|
|
||||||
# Decode all the bytes fields from base64
|
|
||||||
device_index = ipc_info["device_index"]
|
|
||||||
ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
|
|
||||||
storage_size = ipc_info["storage_size"]
|
|
||||||
storage_offset_orig = ipc_info["storage_offset_orig"]
|
|
||||||
ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
|
|
||||||
ref_counter_offset = ipc_info["ref_counter_offset"]
|
|
||||||
event_handle = base64.b64decode(ipc_info["event_handle_b64"])
|
|
||||||
event_sync_required = ipc_info["event_sync_required"]
|
|
||||||
|
|
||||||
if attached_count == 0:
|
|
||||||
print(f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", flush=True)
|
|
||||||
print(f"[Setup DEBUG] About to call _new_shared_cuda...", flush=True)
|
|
||||||
|
|
||||||
# Reconstruct the 8-tuple that _new_shared_cuda expects
|
|
||||||
share_tuple = (
|
|
||||||
device_index,
|
|
||||||
ipc_handle,
|
|
||||||
storage_size,
|
|
||||||
storage_offset_orig,
|
|
||||||
ref_counter_handle,
|
|
||||||
ref_counter_offset,
|
|
||||||
event_handle,
|
|
||||||
event_sync_required,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create storage from IPC handle (needs all 8 items)
|
|
||||||
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
|
|
||||||
|
|
||||||
if attached_count == 0:
|
|
||||||
print(f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True)
|
|
||||||
|
|
||||||
# Reconstruct tensor
|
|
||||||
dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
|
|
||||||
tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
|
|
||||||
tensor.set_(
|
|
||||||
storage,
|
|
||||||
storage_offset=ipc_info["tensor_storage_offset"],
|
|
||||||
size=ipc_info["shape"],
|
|
||||||
stride=ipc_info["stride"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if attached_count == 0:
|
|
||||||
print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True)
|
|
||||||
|
|
||||||
# Make tensor require gradients for training
|
|
||||||
tensor.requires_grad_(True)
|
tensor.requires_grad_(True)
|
||||||
|
|
||||||
hf_state_dict[hf_name] = tensor
|
hf_state_dict[hf_name] = tensor
|
||||||
attached_count += 1
|
attached_count += 1
|
||||||
|
|
||||||
if attached_count == 1:
|
if attached_count == 1:
|
||||||
print(f"[Setup DEBUG] ✓ First tensor attached successfully!", flush=True)
|
print(f"[Setup DEBUG] ✓ First tensor attached: {hf_name}", flush=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
|
print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
print(f"[Setup] Direct attachments: {attached_count}")
|
||||||
|
|
||||||
|
# Now handle fused mappings (QKV, gate_up)
|
||||||
|
fused_count = 0
|
||||||
|
for hf_name, (vllm_name, fuse_type, slice_idx) in fused_mappings.items():
|
||||||
|
if vllm_name not in ipc_handles:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get or create the fused tensor
|
||||||
|
if vllm_name not in fused_tensor_cache:
|
||||||
|
ipc_info = ipc_handles[vllm_name]
|
||||||
|
if "ipc_handle_b64" not in ipc_info:
|
||||||
|
continue
|
||||||
|
fused_tensor_cache[vllm_name] = _create_ipc_tensor(ipc_info)
|
||||||
|
|
||||||
|
fused_tensor = fused_tensor_cache[vllm_name]
|
||||||
|
|
||||||
|
# Slice the fused tensor
|
||||||
|
# vLLM fuses along the first dimension (output features)
|
||||||
|
if fuse_type == 'qkv':
|
||||||
|
# QKV is fused: [hidden, (q_size + k_size + v_size)]
|
||||||
|
# For Qwen: q_size = num_heads * head_dim, k_size = v_size = num_kv_heads * head_dim
|
||||||
|
total_size = fused_tensor.shape[0]
|
||||||
|
# Assume equal splits for Q, K, V (common case)
|
||||||
|
# Actually for GQA: Q is larger, K=V are smaller
|
||||||
|
# We'll get sizes from the HF model's expected shapes
|
||||||
|
hf_param = dict(model.named_parameters())[hf_name]
|
||||||
|
expected_size = hf_param.shape[0]
|
||||||
|
|
||||||
|
if slice_idx == 0: # Q
|
||||||
|
sliced = fused_tensor[:expected_size]
|
||||||
|
elif slice_idx == 1: # K
|
||||||
|
# K starts after Q
|
||||||
|
q_size = dict(model.named_parameters())[hf_name.replace('.k_proj.', '.q_proj.')].shape[0]
|
||||||
|
sliced = fused_tensor[q_size:q_size + expected_size]
|
||||||
|
else: # V
|
||||||
|
q_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.q_proj.')].shape[0]
|
||||||
|
k_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.k_proj.')].shape[0]
|
||||||
|
sliced = fused_tensor[q_size + k_size:q_size + k_size + expected_size]
|
||||||
|
|
||||||
|
elif fuse_type == 'gate_up':
|
||||||
|
# gate_up is fused: [hidden, gate_size + up_size]
|
||||||
|
total_size = fused_tensor.shape[0]
|
||||||
|
half_size = total_size // 2
|
||||||
|
if slice_idx == 0: # gate
|
||||||
|
sliced = fused_tensor[:half_size]
|
||||||
|
else: # up
|
||||||
|
sliced = fused_tensor[half_size:]
|
||||||
|
|
||||||
|
# The slice shares memory with the original!
|
||||||
|
sliced.requires_grad_(True)
|
||||||
|
hf_state_dict[hf_name] = sliced
|
||||||
|
fused_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Setup] Failed to slice {hf_name} from {vllm_name}: {e}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"[Setup] Fused slices: {fused_count}")
|
||||||
|
attached_count += fused_count
|
||||||
|
|
||||||
if attached_count == 0:
|
if attached_count == 0:
|
||||||
print("[Setup] Could not attach any tensors - IPC failed")
|
print("[Setup] Could not attach any tensors - IPC failed")
|
||||||
print("[Setup] Model loaded normally (not sharing memory with vLLM)")
|
print("[Setup] Model loaded normally (not sharing memory with vLLM)")
|
||||||
|
|
@ -656,21 +688,18 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
|
||||||
"""
|
"""
|
||||||
Create mapping from HuggingFace parameter names to vLLM tensor names.
|
Create mapping from HuggingFace parameter names to vLLM tensor names.
|
||||||
|
|
||||||
vLLM uses slightly different naming conventions than HuggingFace.
|
vLLM uses fused layers that need special handling:
|
||||||
This function creates the bidirectional mapping.
|
- qkv_proj -> q_proj, k_proj, v_proj (need slicing)
|
||||||
|
- gate_up_proj -> gate_proj, up_proj (need slicing)
|
||||||
"""
|
"""
|
||||||
hf_params = set(model.state_dict().keys())
|
hf_params = set(model.state_dict().keys())
|
||||||
vllm_params = set(ipc_handles.keys())
|
vllm_params = set(ipc_handles.keys())
|
||||||
|
|
||||||
print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
|
print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
|
||||||
|
|
||||||
# Debug: show sample names
|
|
||||||
hf_sample = sorted(list(hf_params))[:3]
|
|
||||||
vllm_sample = sorted(list(vllm_params))[:3]
|
|
||||||
print(f"[Setup] Sample HF names: {hf_sample}")
|
|
||||||
print(f"[Setup] Sample vLLM names: {vllm_sample}")
|
|
||||||
|
|
||||||
mapping = {}
|
mapping = {}
|
||||||
|
# Track fused tensors that need slicing
|
||||||
|
fused_mappings = {} # hf_name -> (vllm_name, slice_type, slice_index)
|
||||||
|
|
||||||
for hf_name in hf_params:
|
for hf_name in hf_params:
|
||||||
# Try direct match first
|
# Try direct match first
|
||||||
|
|
@ -678,18 +707,35 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
|
||||||
mapping[hf_name] = hf_name
|
mapping[hf_name] = hf_name
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Try common transformations
|
# Check for fused QKV projection
|
||||||
# vLLM often uses 'model.' prefix
|
# HF: model.layers.X.self_attn.q_proj.weight -> vLLM: model.layers.X.self_attn.qkv_proj.weight
|
||||||
vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
|
if '.self_attn.q_proj.' in hf_name:
|
||||||
if vllm_name in vllm_params:
|
vllm_name = hf_name.replace('.q_proj.', '.qkv_proj.')
|
||||||
mapping[hf_name] = vllm_name
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Remove 'model.' prefix if present
|
|
||||||
if hf_name.startswith("model."):
|
|
||||||
vllm_name = hf_name[6:]
|
|
||||||
if vllm_name in vllm_params:
|
if vllm_name in vllm_params:
|
||||||
mapping[hf_name] = vllm_name
|
fused_mappings[hf_name] = (vllm_name, 'qkv', 0) # Q is first
|
||||||
|
continue
|
||||||
|
if '.self_attn.k_proj.' in hf_name:
|
||||||
|
vllm_name = hf_name.replace('.k_proj.', '.qkv_proj.')
|
||||||
|
if vllm_name in vllm_params:
|
||||||
|
fused_mappings[hf_name] = (vllm_name, 'qkv', 1) # K is second
|
||||||
|
continue
|
||||||
|
if '.self_attn.v_proj.' in hf_name:
|
||||||
|
vllm_name = hf_name.replace('.v_proj.', '.qkv_proj.')
|
||||||
|
if vllm_name in vllm_params:
|
||||||
|
fused_mappings[hf_name] = (vllm_name, 'qkv', 2) # V is third
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for fused gate_up projection
|
||||||
|
# HF: model.layers.X.mlp.gate_proj.weight -> vLLM: model.layers.X.mlp.gate_up_proj.weight
|
||||||
|
if '.mlp.gate_proj.' in hf_name:
|
||||||
|
vllm_name = hf_name.replace('.gate_proj.', '.gate_up_proj.')
|
||||||
|
if vllm_name in vllm_params:
|
||||||
|
fused_mappings[hf_name] = (vllm_name, 'gate_up', 0) # gate is first
|
||||||
|
continue
|
||||||
|
if '.mlp.up_proj.' in hf_name:
|
||||||
|
vllm_name = hf_name.replace('.up_proj.', '.gate_up_proj.')
|
||||||
|
if vllm_name in vllm_params:
|
||||||
|
fused_mappings[hf_name] = (vllm_name, 'gate_up', 1) # up is second
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# For lm_head, check if it's tied to embed_tokens
|
# For lm_head, check if it's tied to embed_tokens
|
||||||
|
|
@ -697,13 +743,12 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
|
||||||
mapping[hf_name] = "model.embed_tokens.weight"
|
mapping[hf_name] = "model.embed_tokens.weight"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors")
|
total_mapped = len(mapping) + len(fused_mappings)
|
||||||
|
print(f"[Setup] Direct mappings: {len(mapping)}, Fused mappings: {len(fused_mappings)}")
|
||||||
|
print(f"[Setup] Total mapped: {total_mapped} / {len(hf_params)}")
|
||||||
|
|
||||||
# Show what's NOT mapped
|
# Store fused mappings for later slicing
|
||||||
unmapped = hf_params - set(mapping.keys())
|
mapping['_fused_'] = fused_mappings
|
||||||
if unmapped:
|
|
||||||
unmapped_sample = sorted(list(unmapped))[:5]
|
|
||||||
print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")
|
|
||||||
|
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue