Mirror of https://github.com/NousResearch/atropos.git

Commit 906802299c ("fused memory"), parent 9a95ec5aa1
1 changed file with 132 additions and 87 deletions

@@ -518,87 +518,119 @@ def _attach_to_vllm_shared_tensors(
    hf_state_dict = {}
    vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)

    # Helper to create a tensor from an IPC handle; we need all 8 items
    # from the original _share_cuda_() call.
    def _create_ipc_tensor(ipc_info):
        device_index = ipc_info["device_index"]
        ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
        storage_size = ipc_info["storage_size"]
        storage_offset_orig = ipc_info["storage_offset_orig"]
        ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
        ref_counter_offset = ipc_info["ref_counter_offset"]
        event_handle = base64.b64decode(ipc_info["event_handle_b64"])
        event_sync_required = ipc_info["event_sync_required"]

        # Reconstruct the 8-tuple that _new_shared_cuda expects
        share_tuple = (device_index, ipc_handle, storage_size, storage_offset_orig,
                       ref_counter_handle, ref_counter_offset, event_handle, event_sync_required)

        storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
        dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
        tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
        tensor.set_(storage, storage_offset=ipc_info["tensor_storage_offset"],
                    size=ipc_info["shape"], stride=ipc_info["stride"])
        return tensor
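
For context, here is a minimal sketch of the producer side this helper assumes: the vLLM process exports each tensor's storage with PyTorch's private _share_cuda_() API (the counterpart of _new_shared_cuda) and base64-encodes the binary fields so they survive JSON transport. This is an illustrative reconstruction, not code from this commit, and both APIs are PyTorch internals whose signatures can change between releases:

import base64
import torch

def export_ipc_handle(tensor: torch.Tensor) -> dict:
    # _share_cuda_() returns the 8 fields that _new_shared_cuda() consumes:
    # (device, handle, storage_size, storage_offset, ref_counter_handle,
    #  ref_counter_offset, event_handle, event_sync_required)
    (device_index, ipc_handle, storage_size, storage_offset_orig,
     ref_counter_handle, ref_counter_offset,
     event_handle, event_sync_required) = tensor.untyped_storage()._share_cuda_()
    return {
        "device_index": device_index,
        "ipc_handle_b64": base64.b64encode(ipc_handle).decode(),
        "storage_size": storage_size,
        "storage_offset_orig": storage_offset_orig,
        "ref_counter_handle_b64": base64.b64encode(ref_counter_handle).decode(),
        "ref_counter_offset": ref_counter_offset,
        "event_handle_b64": base64.b64encode(event_handle).decode(),
        "event_sync_required": event_sync_required,
        # Tensor-level metadata needed to re-view the storage on the consumer side
        "dtype": str(tensor.dtype),
        "shape": list(tensor.shape),
        "stride": list(tensor.stride()),
        "tensor_storage_offset": tensor.storage_offset(),
    }
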
    # Cache for fused tensors (avoid recreating them from IPC multiple times)
    fused_tensor_cache = {}

    # Extract fused mappings (handled separately below)
    fused_mappings = vllm_to_hf_mapping.pop('_fused_', {})

    attached_count = 0

    # First, handle direct (one-to-one) mappings
    for hf_name, vllm_name in vllm_to_hf_mapping.items():
        if vllm_name not in ipc_handles:
            continue

        ipc_info = ipc_handles[vllm_name]
        if "ipc_handle_b64" not in ipc_info:
            print(f"[Setup] Missing ipc_handle_b64 for {hf_name}")
            continue

        try:
            # Reconstruct the tensor from its IPC handle and make it
            # trainable; it aliases vLLM's memory rather than copying it.
            tensor = _create_ipc_tensor(ipc_info)
            tensor.requires_grad_(True)

            hf_state_dict[hf_name] = tensor
            attached_count += 1

            if attached_count == 1:
                print(f"[Setup DEBUG] First tensor attached: {hf_name}", flush=True)
        except Exception as e:
            print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            continue

    print(f"[Setup] Direct attachments: {attached_count}")
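
At this point hf_state_dict aliases vLLM's GPU memory rather than copying it. How the dict gets spliced into the HF module is outside this hunk; a hypothetical helper (splice_shared_params is illustrative, not from this commit) could look like:

import torch

def splice_shared_params(model: torch.nn.Module, hf_state_dict: dict) -> None:
    # Re-point each matching module parameter at the shared tensor so
    # optimizer updates write directly into vLLM's weights.
    for name, tensor in hf_state_dict.items():
        module_path, _, param_name = name.rpartition(".")
        module = model.get_submodule(module_path)
        setattr(module, param_name, torch.nn.Parameter(tensor))
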

    # Now handle fused mappings (QKV, gate_up)
    fused_count = 0
    for hf_name, (vllm_name, fuse_type, slice_idx) in fused_mappings.items():
        if vllm_name not in ipc_handles:
            continue

        try:
            # Get or create the fused tensor
            if vllm_name not in fused_tensor_cache:
                ipc_info = ipc_handles[vllm_name]
                if "ipc_handle_b64" not in ipc_info:
                    continue
                fused_tensor_cache[vllm_name] = _create_ipc_tensor(ipc_info)

            fused_tensor = fused_tensor_cache[vllm_name]

            # Slice the fused tensor. vLLM fuses along the first dimension
            # (output features), since Linear weights are [out, in].
            if fuse_type == 'qkv':
                # QKV is fused as [q_size + k_size + v_size, hidden]. For
                # Qwen, q_size = num_heads * head_dim and k_size = v_size =
                # num_kv_heads * head_dim; with GQA, Q is larger than K and V,
                # so take each slice size from the HF model's expected shape
                # rather than assuming equal splits.
                hf_param = dict(model.named_parameters())[hf_name]
                expected_size = hf_param.shape[0]

                if slice_idx == 0:  # Q comes first
                    sliced = fused_tensor[:expected_size]
                elif slice_idx == 1:  # K starts after Q
                    q_size = dict(model.named_parameters())[hf_name.replace('.k_proj.', '.q_proj.')].shape[0]
                    sliced = fused_tensor[q_size:q_size + expected_size]
                else:  # V starts after Q and K
                    q_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.q_proj.')].shape[0]
                    k_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.k_proj.')].shape[0]
                    sliced = fused_tensor[q_size + k_size:q_size + k_size + expected_size]

            elif fuse_type == 'gate_up':
                # gate_up is fused as [gate_size + up_size, hidden]; the two
                # halves are the same size, so split down the middle.
                half_size = fused_tensor.shape[0] // 2
                if slice_idx == 0:  # gate is first
                    sliced = fused_tensor[:half_size]
                else:  # up is second
                    sliced = fused_tensor[half_size:]

            # The slice shares memory with the original fused tensor!
            sliced.requires_grad_(True)
            hf_state_dict[hf_name] = sliced
            fused_count += 1

        except Exception as e:
            print(f"[Setup] Failed to slice {hf_name} from {vllm_name}: {e}", flush=True)
            continue

    print(f"[Setup] Fused slices: {fused_count}")
    attached_count += fused_count

    if attached_count == 0:
        print("[Setup] Could not attach any tensors - IPC failed")
        print("[Setup] Model loaded normally (not sharing memory with vLLM)")
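
As a sanity check of the slicing arithmetic, here is a standalone example with hypothetical GQA dimensions (16 query heads, 4 KV heads, head_dim 128 — placeholder numbers, not taken from this commit):

import torch

head_dim, num_heads, num_kv_heads, hidden = 128, 16, 4, 2048
q_size = num_heads * head_dim      # 2048 rows
kv_size = num_kv_heads * head_dim  # 512 rows each for K and V

fused_qkv = torch.randn(q_size + 2 * kv_size, hidden)  # [3072, 2048]
q = fused_qkv[:q_size]                      # rows [0, 2048)
k = fused_qkv[q_size:q_size + kv_size]      # rows [2048, 2560)
v = fused_qkv[q_size + kv_size:]            # rows [2560, 3072)

intermediate = 4096
fused_gate_up = torch.randn(2 * intermediate, hidden)
gate, up = fused_gate_up[:intermediate], fused_gate_up[intermediate:]

# All slices are views over the fused storage, not copies.
assert q.data_ptr() == fused_qkv.data_ptr()
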
@@ -656,21 +688,18 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
    """
    Create a mapping from HuggingFace parameter names to vLLM tensor names.

    vLLM uses slightly different naming conventions than HuggingFace, and it
    fuses some layers, which need special handling:
    - qkv_proj -> q_proj, k_proj, v_proj (needs slicing)
    - gate_up_proj -> gate_proj, up_proj (needs slicing)
    """
    hf_params = set(model.state_dict().keys())
    vllm_params = set(ipc_handles.keys())

    print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")

    # Debug: show sample names
    hf_sample = sorted(hf_params)[:3]
    vllm_sample = sorted(vllm_params)[:3]
    print(f"[Setup] Sample HF names: {hf_sample}")
    print(f"[Setup] Sample vLLM names: {vllm_sample}")

    mapping = {}
    # Track fused tensors that need slicing later
    fused_mappings = {}  # hf_name -> (vllm_name, fuse_type, slice_index)

    for hf_name in hf_params:
        # Try direct match first

@@ -678,18 +707,35 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
            mapping[hf_name] = hf_name
            continue

        # Try common transformations:
        # vLLM often uses a 'model.' prefix
        vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
        if vllm_name in vllm_params:
            mapping[hf_name] = vllm_name
            continue

        # ... or omits the 'model.' prefix
        if hf_name.startswith("model."):
            vllm_name = hf_name[6:]
            if vllm_name in vllm_params:
                mapping[hf_name] = vllm_name
                continue

        # Check for fused QKV projection
        # HF: model.layers.X.self_attn.q_proj.weight -> vLLM: model.layers.X.self_attn.qkv_proj.weight
        if '.self_attn.q_proj.' in hf_name:
            vllm_name = hf_name.replace('.q_proj.', '.qkv_proj.')
            if vllm_name in vllm_params:
                fused_mappings[hf_name] = (vllm_name, 'qkv', 0)  # Q is first
                continue
        if '.self_attn.k_proj.' in hf_name:
            vllm_name = hf_name.replace('.k_proj.', '.qkv_proj.')
            if vllm_name in vllm_params:
                fused_mappings[hf_name] = (vllm_name, 'qkv', 1)  # K is second
                continue
        if '.self_attn.v_proj.' in hf_name:
            vllm_name = hf_name.replace('.v_proj.', '.qkv_proj.')
            if vllm_name in vllm_params:
                fused_mappings[hf_name] = (vllm_name, 'qkv', 2)  # V is third
                continue

        # Check for fused gate_up projection
        # HF: model.layers.X.mlp.gate_proj.weight -> vLLM: model.layers.X.mlp.gate_up_proj.weight
        if '.mlp.gate_proj.' in hf_name:
            vllm_name = hf_name.replace('.gate_proj.', '.gate_up_proj.')
            if vllm_name in vllm_params:
                fused_mappings[hf_name] = (vllm_name, 'gate_up', 0)  # gate is first
                continue
        if '.mlp.up_proj.' in hf_name:
            vllm_name = hf_name.replace('.up_proj.', '.gate_up_proj.')
            if vllm_name in vllm_params:
                fused_mappings[hf_name] = (vllm_name, 'gate_up', 1)  # up is second
                continue

        # For lm_head, check whether it's tied to embed_tokens

@@ -697,13 +743,12 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
            mapping[hf_name] = "model.embed_tokens.weight"
            continue

    total_mapped = len(mapping) + len(fused_mappings)
    print(f"[Setup] Direct mappings: {len(mapping)}, Fused mappings: {len(fused_mappings)}")
    print(f"[Setup] Total mapped: {total_mapped} / {len(hf_params)}")

    # Show what's NOT mapped (fused params are covered via fused_mappings)
    unmapped = hf_params - set(mapping.keys()) - set(fused_mappings.keys())
    if unmapped:
        unmapped_sample = sorted(unmapped)[:5]
        print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")

    # Store fused mappings for later slicing
    mapping['_fused_'] = fused_mappings

    return mapping
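
To make the returned structure concrete, here is what the mapping could look like for one decoder layer of a Qwen-style model (illustrative names, not output captured from this commit), along with how _attach_to_vllm_shared_tensors consumes the '_fused_' entry:

mapping = {
    "model.layers.0.self_attn.o_proj.weight": "model.layers.0.self_attn.o_proj.weight",
    "model.layers.0.mlp.down_proj.weight": "model.layers.0.mlp.down_proj.weight",
    "_fused_": {
        "model.layers.0.self_attn.q_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 0),
        "model.layers.0.self_attn.k_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 1),
        "model.layers.0.self_attn.v_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 2),
        "model.layers.0.mlp.gate_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 0),
        "model.layers.0.mlp.up_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 1),
    },
}

# The attach step separates the two kinds of mappings:
fused_mappings = mapping.pop("_fused_", {})  # hf_name -> (vllm_name, fuse_type, slice_idx)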