mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
buffer efficiency
This commit is contained in:
parent
dff4065982
commit
bf50ed37d9
1 changed file with 39 additions and 30 deletions
|
|
@@ -484,38 +484,34 @@ def _attach_to_vllm_shared_tensors(
|
|||
ipc_handles = deserialize_ipc_handles(ipc_handles_raw)
|
||||
|
||||
print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
|
||||
print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
|
||||
print("[Setup] TRUE SINGLE-COPY MODE - sharing vLLM's GPU memory!")
|
||||
|
||||
# Load model config (not weights) to get architecture
|
||||
from transformers import AutoConfig
|
||||
model_config = AutoConfig.from_pretrained(config.model_name)
|
||||
|
||||
# Create empty model on meta device (no memory allocation)
|
||||
with torch.device('meta'):
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Get parameter names from the empty model
|
||||
param_names = list(model.state_dict().keys())
|
||||
print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
|
||||
|
||||
# Initialize CUDA before IPC operations
|
||||
# Get the device indices we'll be using
|
||||
# Get device from IPC handles
|
||||
device_indices = set()
|
||||
for name, info in ipc_handles.items():
|
||||
if "device_index" in info:
|
||||
device_indices.add(info["device_index"])
|
||||
|
||||
print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True)
|
||||
device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0"
|
||||
print(f"[Setup] Target device: {device}", flush=True)
|
||||
|
||||
# Initialize CUDA context on each device
|
||||
for dev_idx in sorted(device_indices):
|
||||
print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True)
|
||||
torch.cuda.set_device(dev_idx)
|
||||
torch.cuda.synchronize(dev_idx)
|
||||
print(f"[Setup] ✓ Device {dev_idx} ready", flush=True)
|
||||
# Initialize CUDA context
|
||||
torch.cuda.set_device(int(device.split(':')[1]))
|
||||
torch.cuda.synchronize()
|
||||
print(f"[Setup] ✓ CUDA initialized", flush=True)
|
||||
|
||||
# APPROACH: Load model skeleton on CPU, then swap in IPC tensors
|
||||
# This ensures buffers (like rotary_emb) are properly initialized
|
||||
print(f"[Setup] Loading model structure...", flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device, # Load directly to GPU
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
param_names = list(model.state_dict().keys())
|
||||
print(f"[Setup] Model has {len(param_names)} parameters", flush=True)
|
||||
|
||||
# Map vLLM tensor names to HuggingFace model parameter names
|
||||
hf_state_dict = {}
|
||||
|
|
@@ -603,13 +599,26 @@ def _attach_to_vllm_shared_tensors(
|
|||
continue
|
||||
|
||||
if attached_count == 0:
|
||||
print("[Setup] Could not attach any tensors, falling back to regular loading")
|
||||
return None
|
||||
print("[Setup] Could not attach any tensors - IPC failed")
|
||||
print("[Setup] Model loaded normally (not sharing memory with vLLM)")
|
||||
return model # Return the normally-loaded model
|
||||
|
||||
print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory")
|
||||
print(f"[Setup] ✓ Swapped {attached_count} tensors to share vLLM's memory")
|
||||
|
||||
# Load state dict into model
|
||||
model.load_state_dict(hf_state_dict, strict=False, assign=True)
|
||||
# Now swap the model's parameters with the IPC tensors
|
||||
# This makes the model use vLLM's memory directly
|
||||
swap_count = 0
|
||||
for name, param in model.named_parameters():
|
||||
if name in hf_state_dict:
|
||||
ipc_tensor = hf_state_dict[name]
|
||||
# Verify shapes match
|
||||
if param.shape == ipc_tensor.shape:
|
||||
param.data = ipc_tensor
|
||||
swap_count += 1
|
||||
else:
|
||||
print(f"[Setup] Shape mismatch for {name}: model={param.shape}, ipc={ipc_tensor.shape}")
|
||||
|
||||
print(f"[Setup] ✓ {swap_count} parameters now share vLLM's GPU memory!")
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue