other changes

This commit is contained in:
Jai Suphavadeeprasit 2026-01-13 22:20:43 -05:00
parent 6efec3f1c5
commit 9a95ec5aa1

View file

@ -631,10 +631,23 @@ def _attach_to_vllm_shared_tensors(
if buffer_count > 0:
print(f"[Setup] ✓ Moved {buffer_count} buffers to {device}")
# Also move any remaining CPU parameters (shouldn't be many)
# Check for unmatched parameters still on CPU
# These indicate a name mapping issue - don't try to move large ones
unmatched_cpu = []
for name, param in model.named_parameters():
if param.device.type == 'cpu':
param.data = param.to(device)
unmatched_cpu.append((name, param.numel()))
if unmatched_cpu:
total_unmatched = sum(n for _, n in unmatched_cpu)
print(f"[Setup] WARNING: {len(unmatched_cpu)} parameters ({total_unmatched:,} elements) not matched!")
print(f"[Setup] First few unmatched: {[n for n, _ in unmatched_cpu[:5]]}")
print(f"[Setup] This may cause issues - check name mapping between vLLM and HuggingFace")
# Only move small parameters (< 1M elements) - skip large unmatched ones
for name, param in model.named_parameters():
if param.device.type == 'cpu' and param.numel() < 1_000_000:
param.data = param.to(device)
return model
@ -649,6 +662,14 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
hf_params = set(model.state_dict().keys())
vllm_params = set(ipc_handles.keys())
print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
# Debug: show sample names
hf_sample = sorted(list(hf_params))[:3]
vllm_sample = sorted(list(vllm_params))[:3]
print(f"[Setup] Sample HF names: {hf_sample}")
print(f"[Setup] Sample vLLM names: {vllm_sample}")
mapping = {}
for hf_name in hf_params:
@ -669,6 +690,20 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
vllm_name = hf_name[6:]
if vllm_name in vllm_params:
mapping[hf_name] = vllm_name
continue
# For lm_head, check if it's tied to embed_tokens
if hf_name == "lm_head.weight" and "model.embed_tokens.weight" in vllm_params:
mapping[hf_name] = "model.embed_tokens.weight"
continue
print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors")
# Show what's NOT mapped
unmapped = hf_params - set(mapping.keys())
if unmapped:
unmapped_sample = sorted(list(unmapped))[:5]
print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")
return mapping