mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
main changes
This commit is contained in:
parent
2dc1c2a981
commit
2225b4623f
2 changed files with 35 additions and 11 deletions
|
|
@@ -514,14 +514,21 @@ def _attach_to_vllm_shared_tensors(
|
|||
|
||||
try:
|
||||
# Reconstruct tensor from IPC handle
|
||||
# Handle can be bytes (deserialized from base64) or hex string
|
||||
handle = ipc_info["handle"]
|
||||
if isinstance(handle, bytes):
|
||||
handle_bytes = handle
|
||||
elif isinstance(handle, str):
|
||||
handle_bytes = bytes.fromhex(handle)
|
||||
# Handle is base64-encoded in the JSON
|
||||
if "handle_b64" in ipc_info:
|
||||
handle_bytes = base64.b64decode(ipc_info["handle_b64"])
|
||||
elif "handle" in ipc_info:
|
||||
# Legacy format - hex string or bytes
|
||||
handle = ipc_info["handle"]
|
||||
if isinstance(handle, bytes):
|
||||
handle_bytes = handle
|
||||
elif isinstance(handle, str):
|
||||
handle_bytes = bytes.fromhex(handle)
|
||||
else:
|
||||
print(f"[Setup] Unknown handle type for {hf_name}: {type(handle)}")
|
||||
continue
|
||||
else:
|
||||
print(f"[Setup] Unknown handle type for {hf_name}: {type(handle)}")
|
||||
print(f"[Setup] No handle found for {hf_name}")
|
||||
continue
|
||||
|
||||
storage_size = ipc_info["storage_size"]
|
||||
|
|
|
|||
|
|
@@ -244,12 +244,27 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
# Export CUDA IPC handles for true single-copy mode
|
||||
if tensor.is_cuda:
|
||||
try:
|
||||
# Get the storage's IPC handle
|
||||
import base64
|
||||
# Get the storage's IPC handle tuple
|
||||
storage = tensor.untyped_storage()
|
||||
ipc_handle = storage._share_cuda_()
|
||||
# _share_cuda_() returns: (handle, storage_size, storage_offset, ...)
|
||||
share_data = storage._share_cuda_()
|
||||
|
||||
# Convert handle to bytes - it's a cudaIpcMemHandle_t (64 bytes)
|
||||
handle = share_data[0]
|
||||
if isinstance(handle, bytes):
|
||||
handle_bytes = handle
|
||||
elif hasattr(handle, '__bytes__'):
|
||||
handle_bytes = bytes(handle)
|
||||
else:
|
||||
# For cudaIpcMemHandle_t object, get raw bytes via memoryview
|
||||
import ctypes
|
||||
# cudaIpcMemHandle_t is 64 bytes
|
||||
handle_bytes = bytes(memoryview(handle).cast('B')[:64])
|
||||
|
||||
ipc_handles[name] = {
|
||||
"handle": ipc_handle[0].hex() if isinstance(ipc_handle[0], bytes) else str(ipc_handle[0]),
|
||||
"storage_size": ipc_handle[1],
|
||||
"handle_b64": base64.b64encode(handle_bytes).decode('ascii'),
|
||||
"storage_size": share_data[1],
|
||||
"storage_offset": tensor.storage_offset(),
|
||||
"shape": list(tensor.shape),
|
||||
"stride": list(tensor.stride()),
|
||||
|
|
@@ -258,6 +273,8 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
}
|
||||
except Exception as e:
|
||||
print(f"[vLLM Patch] Could not get IPC handle for {name}: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print(f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", flush=True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue