mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
major refactor
This commit is contained in:
parent
119721ef3d
commit
6833d4d820
13 changed files with 3268 additions and 3423 deletions
607
example_trainer/model.py
Normal file
607
example_trainer/model.py
Normal file
|
|
@ -0,0 +1,607 @@
|
|||
"""
|
||||
Model loading and shared memory utilities for GRPO trainer.
|
||||
|
||||
Handles:
|
||||
- Standard model loading (legacy mode)
|
||||
- LoRA model loading and wrapping
|
||||
- Single-copy mode: Attaching to vLLM's shared tensors via CUDA IPC
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .config import TrainingConfig
|
||||
|
||||
# Import PEFT for LoRA training
|
||||
try:
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
PEFT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PEFT_AVAILABLE = False
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
    config: TrainingConfig,
    single_copy: bool = False,
) -> Tuple[torch.nn.Module, AutoTokenizer]:
    """
    Load or attach to the policy model, based on weight_bridge_mode.

    Modes:
      - single_copy / "shared_vllm": attach to vLLM's tensors via CUDA IPC
      - "lora_only": load base model and wrap with LoRA adapters
      - anything else: legacy full-model load

    Args:
        config: Training configuration
        single_copy: If True, attach to vLLM's shared tensors via CUDA IPC

    Returns:
        Tuple of (model, tokenizer)

    Raises:
        RuntimeError: if single-copy mode is requested but attaching fails.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    wants_single_copy = single_copy or config.weight_bridge_mode == "shared_vllm"

    if wants_single_copy:
        bridge_path = _find_vllm_config(config)
        model = _attach_to_vllm_shared_tensors(config, bridge_path)

        if model is None:
            raise RuntimeError(
                "[Setup] Single-copy mode FAILED to attach to vLLM's tensors.\n"
                "Check:\n"
                " 1. vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1\n"
                " 2. vllm_bridge_config.json exists with ipc_handles\n"
                " 3. Trainer is on SAME GPUs as vLLM"
            )

        # Shared-tensor path skips gradient checkpointing setup entirely.
        print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!")
        model.train()
        return model, tokenizer

    if config.weight_bridge_mode == "lora_only":
        model = _load_model_with_lora(config)
    else:
        # Legacy mode: load full model
        print("[Setup] Loading model for legacy mode...")
        model = AutoModelForCausalLM.from_pretrained(
            config.model_name, torch_dtype=torch.bfloat16
        )
        model.to(config.device)

    # Enable gradient checkpointing
    _setup_gradient_checkpointing(model, config)
    model.train()

    return model, tokenizer
|
||||
|
||||
|
||||
def _find_vllm_config(config: TrainingConfig) -> str:
|
||||
"""Find the vllm_bridge_config.json file."""
|
||||
# Check explicit path first
|
||||
if config.vllm_config_path and os.path.exists(config.vllm_config_path):
|
||||
print(f"[Setup] Using explicit vLLM config path: {config.vllm_config_path}")
|
||||
return config.vllm_config_path
|
||||
|
||||
# Auto-detect from common locations
|
||||
possible_paths = [
|
||||
os.environ.get("LOGDIR", "."),
|
||||
".",
|
||||
"/tmp/atropos_bridge",
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
]
|
||||
|
||||
for log_dir in possible_paths:
|
||||
candidate = os.path.join(log_dir, "vllm_bridge_config.json")
|
||||
if os.path.exists(candidate):
|
||||
print(f"[Setup] Found vLLM config at: {candidate}")
|
||||
return candidate
|
||||
|
||||
checked = [os.path.join(p, "vllm_bridge_config.json") for p in possible_paths]
|
||||
raise RuntimeError(
|
||||
f"[Setup] Could not find vllm_bridge_config.json\n"
|
||||
f"Checked: {checked}\n"
|
||||
f"Tip: Use --vllm-config-path to specify the path explicitly\n"
|
||||
f"Make sure vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set"
|
||||
)
|
||||
|
||||
|
||||
def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
    """
    Load the base model and wrap it with LoRA adapters.

    Args:
        config: Training configuration with LoRA settings
        (r, alpha, dropout, target modules)

    Returns:
        PEFT model with LoRA adapters applied

    Raises:
        RuntimeError: if the peft package is not installed.
    """
    if not PEFT_AVAILABLE:
        raise RuntimeError("PEFT library not available. Install with: pip install peft")

    print("[Setup] Loading base model for LoRA mode...")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name, torch_dtype=torch.bfloat16
    )
    base_model.to(config.device)

    # Fall back to attention projections when no target modules are configured.
    modules = (
        ["q_proj", "v_proj"]
        if config.lora_target_modules is None
        else config.lora_target_modules
    )

    print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Target modules: {modules}")

    adapter_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=modules,
        bias="none",
    )

    peft_model = get_peft_model(base_model, adapter_config)
    peft_model.print_trainable_parameters()

    return peft_model
|
||||
|
||||
|
||||
def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None:
|
||||
"""Configure gradient checkpointing for the model."""
|
||||
# Disable KV cache - incompatible with gradient checkpointing
|
||||
model.config.use_cache = False
|
||||
|
||||
if config.weight_bridge_mode == "lora_only":
|
||||
# PEFT models need special handling
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
else:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
|
||||
def _attach_to_vllm_shared_tensors(
    config: TrainingConfig,
    bridge_config_path: str,
) -> Optional[torch.nn.Module]:
    """
    Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode).

    This creates a model whose parameters point to the SAME GPU memory as vLLM,
    meaning only ONE copy of the model exists in GPU memory.

    Pipeline: read bridge config -> deserialize IPC handles -> build an empty
    model on the meta device -> map HF param names to vLLM tensor names ->
    reconstruct shared tensors -> load them with assign=True -> zero-init any
    leftovers -> final meta-device check.

    Args:
        config: Training configuration
        bridge_config_path: Path to vllm_bridge_config.json

    Returns:
        Model with parameters pointing to vLLM's tensors, or None if not possible
        (unreadable config, sharing disabled, no handles, or nothing attached).
    """
    print(f"[Setup] Reading bridge config from: {bridge_config_path}")
    try:
        with open(bridge_config_path, "r") as f:
            bridge_config = json.load(f)
        print(f"[Setup] Bridge config keys: {list(bridge_config.keys())}")
    except Exception as e:
        # Best-effort: an unreadable config means we fall back to normal loading.
        print(f"[Setup] Could not read bridge config: {e}")
        return None

    single_copy_enabled = bridge_config.get("single_copy_enabled", False)
    print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}")

    if not single_copy_enabled:
        print("[Setup] Single-copy mode not available (single_copy_enabled=False)")
        print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1")
        return None

    ipc_handles_raw = bridge_config.get("ipc_handles", {})
    print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}")
    if not ipc_handles_raw:
        print("[Setup] No IPC handles found in bridge config")
        return None

    # Deserialize base64-encoded bytes (JSON cannot carry raw bytes)
    ipc_handles = _deserialize_ipc_handles(ipc_handles_raw)

    print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
    print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")

    # Load model config (not weights) to get architecture
    model_config = AutoConfig.from_pretrained(config.model_name)

    # Create empty model on meta device (no memory allocation)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(
            model_config,
            torch_dtype=torch.bfloat16,
        )

    param_names = list(model.state_dict().keys())
    print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)

    # Initialize CUDA on devices used by vLLM
    device_indices = _initialize_cuda_devices(ipc_handles)

    # Create mapping from HF names to vLLM tensors (handles fused qkv/gate_up)
    vllm_to_hf_mapping = _create_vllm_to_hf_mapping(
        model, ipc_handles, debug=config.debug_loading
    )

    # Reconstruct tensors and build state dict
    hf_state_dict, attached_count, fused_count = _reconstruct_shared_tensors(
        ipc_handles, vllm_to_hf_mapping, config
    )

    print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)")

    if attached_count == 0:
        print("[Setup] Could not attach any tensors, falling back to regular loading")
        return None

    # Validate mapping coverage (warns < 90%, raises < 50%)
    _validate_mapping_coverage(model, hf_state_dict, attached_count)

    # Load state dict into model. assign=True swaps the shared tensors in
    # without copying, so parameters keep pointing at vLLM's memory.
    model.load_state_dict(hf_state_dict, strict=False, assign=True)

    # Initialize remaining meta tensors on one of the GPUs vLLM used
    device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0"
    _initialize_meta_tensors(model, device, config)

    # Final validation - ensure nothing is left on meta device
    _validate_no_meta_tensors(model)

    print("[Setup] ✓ All tensors successfully initialized on CUDA")
    return model
|
||||
|
||||
|
||||
def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
||||
"""Deserialize base64-encoded bytes in IPC handles."""
|
||||
def deserialize(handles):
|
||||
result = {}
|
||||
for k, v in handles.items():
|
||||
if isinstance(v, dict):
|
||||
if "_bytes_b64_" in v:
|
||||
result[k] = base64.b64decode(v["_bytes_b64_"])
|
||||
else:
|
||||
result[k] = deserialize(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
return deserialize(handles_raw)
|
||||
|
||||
|
||||
def _initialize_cuda_devices(ipc_handles: dict) -> set:
|
||||
"""Initialize CUDA context on devices used by IPC handles."""
|
||||
device_indices = set()
|
||||
for name, info in ipc_handles.items():
|
||||
if "device_index" in info:
|
||||
device_indices.add(info["device_index"])
|
||||
|
||||
print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True)
|
||||
|
||||
for dev_idx in sorted(device_indices):
|
||||
print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True)
|
||||
torch.cuda.set_device(dev_idx)
|
||||
torch.cuda.synchronize(dev_idx)
|
||||
print(f"[Setup] ✓ Device {dev_idx} ready", flush=True)
|
||||
|
||||
return device_indices
|
||||
|
||||
|
||||
def _reconstruct_shared_tensors(
    ipc_handles: dict,
    vllm_to_hf_mapping: dict,
    config: TrainingConfig,
) -> Tuple[dict, int, int]:
    """Reconstruct tensors from IPC handles and build state dict.

    Args:
        ipc_handles: vLLM tensor name -> IPC metadata (handles, dtype, shape,
            stride, offsets) as emitted in vllm_bridge_config.json.
        vllm_to_hf_mapping: HF param name -> either a vLLM tensor name (direct
            mapping) or a dict slice descriptor for fused layers.
        config: Training configuration (not read in the body; kept for
            signature parity with the other setup helpers).

    Returns:
        Tuple of (hf_state_dict, attached_count, fused_count). Failures on
        individual tensors are logged and skipped, not raised.
    """
    hf_state_dict = {}
    # Cache reconstructed tensors so each fused source (e.g. qkv_proj) is
    # rebuilt from its IPC handle only once, however many slices it feeds.
    vllm_tensor_cache: Dict[str, torch.Tensor] = {}
    attached_count = 0
    fused_count = 0

    def reconstruct_vllm_tensor(vllm_name: str) -> Optional[torch.Tensor]:
        # Rebuild, in this process, a CUDA tensor backed by vLLM's memory.
        # Returns None (never raises) when the handle is missing or broken.
        if vllm_name in vllm_tensor_cache:
            return vllm_tensor_cache[vllm_name]

        if vllm_name not in ipc_handles:
            return None

        ipc_info = ipc_handles[vllm_name]
        if "ipc_handle_b64" not in ipc_info:
            return None

        try:
            device_index = ipc_info["device_index"]
            ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
            storage_size = ipc_info["storage_size"]
            storage_offset_orig = ipc_info["storage_offset_orig"]
            ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
            ref_counter_offset = ipc_info["ref_counter_offset"]
            event_handle = base64.b64decode(ipc_info["event_handle_b64"])
            event_sync_required = ipc_info["event_sync_required"]

            # Positional order must match UntypedStorage._new_shared_cuda's
            # expected argument order exactly.
            share_tuple = (
                device_index, ipc_handle, storage_size, storage_offset_orig,
                ref_counter_handle, ref_counter_offset, event_handle, event_sync_required,
            )

            # Open the IPC handle: this maps vLLM's allocation, no copy made.
            storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
            # dtype arrives serialized as e.g. "torch.bfloat16" — map it back
            # to the torch attribute.
            dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
            tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
            # Point the empty tensor at the shared storage with the original
            # geometry (offset/shape/stride).
            tensor.set_(
                storage,
                storage_offset=ipc_info["tensor_storage_offset"],
                size=ipc_info["shape"],
                stride=ipc_info["stride"],
            )

            vllm_tensor_cache[vllm_name] = tensor
            return tensor

        except Exception as e:
            print(f"[Setup] Failed to reconstruct {vllm_name}: {e}", flush=True)
            return None

    for hf_name, mapping_info in vllm_to_hf_mapping.items():
        try:
            if isinstance(mapping_info, dict):
                # Fused mapping - slice the source tensor
                vllm_name = mapping_info["source"]
                slice_start, slice_end = mapping_info["slice"]
                slice_dim = mapping_info["dim"]

                full_tensor = reconstruct_vllm_tensor(vllm_name)
                if full_tensor is None:
                    continue

                # Create VIEW (not copy) into the fused tensor
                if slice_dim == 0:
                    tensor = full_tensor[slice_start:slice_end]
                else:
                    tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start)

                tensor.requires_grad_(True)
                hf_state_dict[hf_name] = tensor
                fused_count += 1
                attached_count += 1

            else:
                # Direct mapping: the mapping value is the vLLM tensor name.
                vllm_name = mapping_info
                tensor = reconstruct_vllm_tensor(vllm_name)
                if tensor is None:
                    continue

                tensor.requires_grad_(True)
                hf_state_dict[hf_name] = tensor
                attached_count += 1

        except Exception as e:
            # Skip this parameter but keep attaching the rest.
            print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)

    return hf_state_dict, attached_count, fused_count
|
||||
|
||||
|
||||
def _validate_mapping_coverage(
|
||||
model: torch.nn.Module,
|
||||
hf_state_dict: dict,
|
||||
attached_count: int,
|
||||
) -> None:
|
||||
"""Validate that enough parameters were mapped."""
|
||||
hf_param_count = len(list(model.named_parameters()))
|
||||
mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0
|
||||
|
||||
print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters")
|
||||
|
||||
if mapping_coverage < 0.90:
|
||||
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
|
||||
warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
|
||||
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
|
||||
for name in list(unmapped_params)[:20]:
|
||||
warning_msg += f" - {name}\n"
|
||||
print(warning_msg)
|
||||
|
||||
if mapping_coverage < 0.50:
|
||||
raise RuntimeError(
|
||||
f"[Setup] CRITICAL: Only {mapping_coverage:.1%} of parameters mapped!"
|
||||
)
|
||||
else:
|
||||
print(f"[Setup] ✓ Good mapping coverage ({mapping_coverage:.1%})")
|
||||
|
||||
|
||||
def _initialize_meta_tensors(
|
||||
model: torch.nn.Module,
|
||||
device: str,
|
||||
config: TrainingConfig,
|
||||
) -> None:
|
||||
"""Initialize any remaining meta tensors after loading."""
|
||||
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
||||
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
||||
|
||||
if config.debug_loading:
|
||||
print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}")
|
||||
|
||||
def get_parent_and_name(model, full_name):
|
||||
parts = full_name.split(".")
|
||||
parent = model
|
||||
for part in parts[:-1]:
|
||||
parent = getattr(parent, part)
|
||||
return parent, parts[-1]
|
||||
|
||||
meta_count = 0
|
||||
|
||||
# Initialize meta parameters
|
||||
for name in meta_params:
|
||||
param = dict(model.named_parameters()).get(name)
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
new_data = torch.zeros(param.shape, dtype=param.dtype, device=device)
|
||||
new_param = torch.nn.Parameter(new_data, requires_grad=param.requires_grad)
|
||||
parent, attr_name = get_parent_and_name(model, name)
|
||||
setattr(parent, attr_name, new_param)
|
||||
meta_count += 1
|
||||
except Exception as e:
|
||||
if config.debug_loading:
|
||||
print(f"[DIAGNOSTIC] FAILED to initialize {name}: {e}")
|
||||
|
||||
# Initialize meta buffers
|
||||
for name in meta_buffers:
|
||||
buffer = dict(model.named_buffers()).get(name)
|
||||
if buffer is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
if "inv_freq" in name:
|
||||
dim = buffer.shape[0] * 2
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
|
||||
else:
|
||||
new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device)
|
||||
|
||||
parent, attr_name = get_parent_and_name(model, name)
|
||||
parent.register_buffer(attr_name, new_buffer)
|
||||
meta_count += 1
|
||||
except Exception as e:
|
||||
if config.debug_loading:
|
||||
print(f"[DIAGNOSTIC] FAILED to initialize buffer {name}: {e}")
|
||||
|
||||
print(f"\n[Setup] Initialized {meta_count} remaining meta tensors")
|
||||
|
||||
|
||||
def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
|
||||
"""Ensure no parameters or buffers are still on meta device."""
|
||||
final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
||||
final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
||||
|
||||
if final_meta_params or final_meta_buffers:
|
||||
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
|
||||
error_msg += "The model would produce GARBAGE output.\n\n"
|
||||
|
||||
if final_meta_params:
|
||||
error_msg += f"Meta parameters ({len(final_meta_params)}):\n"
|
||||
for name in final_meta_params[:20]:
|
||||
error_msg += f" - {name}\n"
|
||||
|
||||
if final_meta_buffers:
|
||||
error_msg += f"\nMeta buffers ({len(final_meta_buffers)}):\n"
|
||||
for name in final_meta_buffers[:20]:
|
||||
error_msg += f" - {name}\n"
|
||||
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def _create_vllm_to_hf_mapping(
|
||||
model: torch.nn.Module,
|
||||
ipc_handles: dict,
|
||||
debug: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Create mapping from HuggingFace parameter names to vLLM tensor names.
|
||||
|
||||
Handles fused layers:
|
||||
- qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF)
|
||||
- gate_up_proj (vLLM) = gate_proj + up_proj (HF)
|
||||
"""
|
||||
hf_params = set(model.state_dict().keys())
|
||||
vllm_params = set(ipc_handles.keys())
|
||||
|
||||
# Get model config for dimension calculations
|
||||
model_config = model.config
|
||||
hidden_size = getattr(model_config, "hidden_size", 4096)
|
||||
num_attention_heads = getattr(model_config, "num_attention_heads", 32)
|
||||
num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads)
|
||||
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
|
||||
# QKV sizes
|
||||
q_size = hidden_size
|
||||
k_size = num_key_value_heads * head_dim
|
||||
v_size = num_key_value_heads * head_dim
|
||||
|
||||
if debug:
|
||||
print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
|
||||
f"kv_heads={num_key_value_heads}, intermediate={intermediate_size}")
|
||||
|
||||
mapping = {}
|
||||
|
||||
def find_vllm_name(hf_name: str) -> Optional[str]:
|
||||
if hf_name in vllm_params:
|
||||
return hf_name
|
||||
if not hf_name.startswith("model."):
|
||||
candidate = f"model.{hf_name}"
|
||||
if candidate in vllm_params:
|
||||
return candidate
|
||||
if hf_name.startswith("model."):
|
||||
candidate = hf_name[6:]
|
||||
if candidate in vllm_params:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def find_fused_source(hf_name: str, fused_suffix: str) -> Optional[str]:
|
||||
for unfused in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]:
|
||||
if unfused in hf_name:
|
||||
fused_name = hf_name.replace(unfused, fused_suffix)
|
||||
found = find_vllm_name(fused_name)
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
for hf_name in hf_params:
|
||||
# Try direct match first
|
||||
vllm_name = find_vllm_name(hf_name)
|
||||
if vllm_name:
|
||||
mapping[hf_name] = vllm_name
|
||||
continue
|
||||
|
||||
# Check for QKV fusion
|
||||
if any(x in hf_name for x in ["q_proj", "k_proj", "v_proj"]):
|
||||
fused_name = find_fused_source(hf_name, "qkv_proj")
|
||||
if fused_name:
|
||||
if "q_proj" in hf_name:
|
||||
start, end = 0, q_size
|
||||
elif "k_proj" in hf_name:
|
||||
start, end = q_size, q_size + k_size
|
||||
else:
|
||||
start, end = q_size + k_size, q_size + k_size + v_size
|
||||
|
||||
mapping[hf_name] = {
|
||||
"source": fused_name,
|
||||
"slice": (start, end),
|
||||
"dim": 0,
|
||||
"type": "qkv_fusion",
|
||||
}
|
||||
continue
|
||||
|
||||
# Check for Gate/Up fusion
|
||||
if any(x in hf_name for x in ["gate_proj", "up_proj"]):
|
||||
fused_name = find_fused_source(hf_name, "gate_up_proj")
|
||||
if fused_name:
|
||||
if "gate_proj" in hf_name:
|
||||
start, end = 0, intermediate_size
|
||||
else:
|
||||
start, end = intermediate_size, intermediate_size * 2
|
||||
|
||||
mapping[hf_name] = {
|
||||
"source": fused_name,
|
||||
"slice": (start, end),
|
||||
"dim": 0,
|
||||
"type": "gate_up_fusion",
|
||||
}
|
||||
continue
|
||||
|
||||
if debug:
|
||||
direct = sum(1 for v in mapping.values() if isinstance(v, str))
|
||||
fused = sum(1 for v in mapping.values() if isinstance(v, dict))
|
||||
print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)")
|
||||
|
||||
return mapping
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue