mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
serialization errors
This commit is contained in:
parent
f46e5c562d
commit
19b3116b84
1 changed file with 40 additions and 6 deletions
|
|
@ -454,21 +454,46 @@ def _attach_to_vllm_shared_tensors(
|
|||
print("[Setup] Single-copy mode not available (no IPC handles exported)")
|
||||
return None
|
||||
|
||||
ipc_handles = bridge_config.get("ipc_handles", {})
|
||||
if not ipc_handles:
|
||||
ipc_handles_raw = bridge_config.get("ipc_handles", {})
|
||||
if not ipc_handles_raw:
|
||||
print("[Setup] No IPC handles found in bridge config")
|
||||
return None
|
||||
|
||||
# Deserialize base64-encoded bytes back to bytes
|
||||
import base64
|
||||
|
||||
def deserialize_ipc_handles(handles):
    """Return a copy of *handles* with base64-wrapped bytes decoded.

    A dict value that contains the key ``"_bytes_b64_"`` is treated as an
    encoded bytes marker and replaced by the decoded ``bytes`` object.
    Other dicts, and lists, are walked recursively so markers nested at
    any depth are decoded. Every other value is passed through unchanged.

    Args:
        handles: Mapping of names to (possibly nested) handle info, as
            read from the bridge config.

    Returns:
        A new dict with the same keys; the input is not mutated.

    Raises:
        binascii.Error: If a marker holds incorrectly padded base64 data.
    """

    def _decode(value):
        if isinstance(value, dict):
            if "_bytes_b64_" in value:
                return base64.b64decode(value["_bytes_b64_"])
            return {key: _decode(item) for key, item in value.items()}
        if isinstance(value, list):
            # JSON has no tuple type, so sequences of handles arrive as
            # lists; decode any markers they contain as well.
            return [_decode(item) for item in value]
        return value

    # Only the *values* of the top-level mapping are inspected for the
    # marker key; the top-level mapping itself is always returned as a dict.
    return {name: _decode(info) for name, info in handles.items()}
|
||||
|
||||
ipc_handles = deserialize_ipc_handles(ipc_handles_raw)
|
||||
|
||||
print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
|
||||
print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
|
||||
|
||||
# Create model architecture (meta device - no memory allocation)
|
||||
# Load model config (not weights) to get architecture
|
||||
from transformers import AutoConfig
|
||||
model_config = AutoConfig.from_pretrained(config.model_name)
|
||||
|
||||
# Create empty model on meta device (no memory allocation)
|
||||
with torch.device('meta'):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Get parameter names from the empty model
|
||||
param_names = list(model.state_dict().keys())
|
||||
print(f"[Setup] Model architecture has {len(param_names)} parameters")
|
||||
|
||||
# Map vLLM tensor names to HuggingFace model parameter names
|
||||
hf_state_dict = {}
|
||||
vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
|
||||
|
|
@ -482,7 +507,16 @@ def _attach_to_vllm_shared_tensors(
|
|||
|
||||
try:
|
||||
# Reconstruct tensor from IPC handle
|
||||
handle_bytes = bytes.fromhex(ipc_info["handle"])
|
||||
# Handle can be bytes (deserialized from base64) or hex string
|
||||
handle = ipc_info["handle"]
|
||||
if isinstance(handle, bytes):
|
||||
handle_bytes = handle
|
||||
elif isinstance(handle, str):
|
||||
handle_bytes = bytes.fromhex(handle)
|
||||
else:
|
||||
print(f"[Setup] Unknown handle type for {hf_name}: {type(handle)}")
|
||||
continue
|
||||
|
||||
storage_size = ipc_info["storage_size"]
|
||||
device_index = ipc_info["device_index"]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue