serialization errors

This commit is contained in:
Jai Suphavadeeprasit 2026-01-13 12:17:48 -05:00
parent f46e5c562d
commit 19b3116b84

View file

@ -454,21 +454,46 @@ def _attach_to_vllm_shared_tensors(
print("[Setup] Single-copy mode not available (no IPC handles exported)")
return None
ipc_handles = bridge_config.get("ipc_handles", {})
if not ipc_handles:
ipc_handles_raw = bridge_config.get("ipc_handles", {})
if not ipc_handles_raw:
print("[Setup] No IPC handles found in bridge config")
return None
# Deserialize base64-encoded bytes back to bytes
import base64
def deserialize_ipc_handles(handles):
    """Return a copy of *handles* with base64 byte markers decoded to ``bytes``.

    A marker is any dict that contains the key ``"_bytes_b64_"``; it is
    replaced by ``base64.b64decode`` of that key's value. Other dicts are
    walked recursively, and so are lists, so a marker nested inside a JSON
    array is decoded too. All remaining values are passed through unchanged.

    Args:
        handles: Mapping to decode. It is not modified; the top-level mapping
            itself is never treated as a marker, only its values are.

    Returns:
        A new dict with the same keys and decoded values.

    Raises:
        binascii.Error: If a marker holds incorrectly padded base64 text.
    """
    def _decode(value):
        if isinstance(value, dict):
            if "_bytes_b64_" in value:
                return base64.b64decode(value["_bytes_b64_"])
            return {key: _decode(item) for key, item in value.items()}
        if isinstance(value, list):
            # Tuples/lists come back from JSON as lists; decode their items
            # as well so markers inside a sequence are not left as dicts.
            return [_decode(item) for item in value]
        return value

    return {key: _decode(value) for key, value in handles.items()}
ipc_handles = deserialize_ipc_handles(ipc_handles_raw)
print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
# Create model architecture (meta device - no memory allocation)
# Load model config (not weights) to get architecture
from transformers import AutoConfig
model_config = AutoConfig.from_pretrained(config.model_name)
# Create empty model on meta device (no memory allocation)
with torch.device('meta'):
model = AutoModelForCausalLM.from_pretrained(
config.model_name,
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch.bfloat16,
)
# Get parameter names from the empty model
param_names = list(model.state_dict().keys())
print(f"[Setup] Model architecture has {len(param_names)} parameters")
# Map vLLM tensor names to HuggingFace model parameter names
hf_state_dict = {}
vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
@ -482,7 +507,16 @@ def _attach_to_vllm_shared_tensors(
try:
# Reconstruct tensor from IPC handle
handle_bytes = bytes.fromhex(ipc_info["handle"])
# Handle can be bytes (deserialized from base64) or hex string
handle = ipc_info["handle"]
if isinstance(handle, bytes):
handle_bytes = handle
elif isinstance(handle, str):
handle_bytes = bytes.fromhex(handle)
else:
print(f"[Setup] Unknown handle type for {hf_name}: {type(handle)}")
continue
storage_size = ipc_info["storage_size"]
device_index = ipc_info["device_index"]