Jai Suphavadeeprasit 2026-02-02 22:59:32 -05:00
parent 24b8ab8574
commit c8884348c7
8 changed files with 360 additions and 820 deletions


@@ -25,6 +25,67 @@ except ImportError:
PEFT_AVAILABLE = False
def _get_attention_implementation() -> str:
"""
Determine the best attention implementation to use.
Priority:
1. Flash Attention 2 (if the flash_attn library is importable)
2. SDPA (PyTorch's scaled dot-product attention)
Returns:
The attn_implementation string: "flash_attention_2" or "sdpa".
"""
try:
import flash_attn
return "flash_attention_2"
except ImportError:
return "sdpa"
def _load_model_with_attention(
model_name_or_config,
torch_dtype=torch.bfloat16,
from_config: bool = False,
) -> torch.nn.Module:
"""
Load a model with the best available attention implementation.
Args:
model_name_or_config: Either a model name (str) or a model config object
torch_dtype: Data type for model weights
from_config: If True, use from_config (for meta device loading - no weights)
If False, use from_pretrained (downloads and loads weights)
Returns:
Loaded model with appropriate attention implementation
"""
# Select the loader function based on mode
# from_config: creates empty shell (meta device), from_pretrained: loads weights
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained
# Try attention implementations in order of preference
for attn_impl in ["flash_attention_2", "sdpa"]:
# Skip flash_attention_2 if not available
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2":
continue
try:
model = loader(
model_name_or_config,
torch_dtype=torch_dtype,
attn_implementation=attn_impl,
)
print(f"[Setup] Using {attn_impl.replace('_', ' ').title()}")
return model
except Exception as e:
if attn_impl == "flash_attention_2":
print(f"[Setup] Flash Attention 2 failed ({e}), trying SDPA...")
continue
raise
# Should never reach here, but just in case
raise RuntimeError("Failed to load model with any attention implementation")
def load_model_and_tokenizer(
config: TrainingConfig,
single_copy: bool = False,
@@ -67,27 +128,7 @@ def load_model_and_tokenizer(
else:
# Legacy mode: load full model
print("[Setup] Loading model for legacy mode...")
flash_available = False
try:
import flash_attn
flash_available = True
except ImportError:
pass
if flash_available:
model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("[Setup] Using Flash Attention 2")
else:
model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
)
print("[Setup] Using SDPA attention")
model = _load_model_with_attention(config.model_name)
model.to(config.device)
# Enable gradient checkpointing
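The checkpointing call itself is cut off by the diff; a minimal sketch, assuming the stock Transformers API rather than code shown in this commit:

# Trade recompute for activation memory during backprop (standard HF API).
model.gradient_checkpointing_enable()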
@@ -111,7 +152,7 @@ def _find_vllm_config(config: TrainingConfig) -> str:
"/tmp/atropos_bridge",
os.path.dirname(os.path.abspath(__file__)),
]
# Look through the possible paths for a bridge config file
for log_dir in possible_paths:
candidate = os.path.join(log_dir, "vllm_bridge_config.json")
if os.path.exists(candidate):
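For orientation, an illustrative shape of the vllm_bridge_config.json this search locates. The two fields shown are the ones read later in this commit; the values are hypothetical:

example_bridge_config = {
    "single_copy_enabled": True,  # set when vLLM ran with VLLM_ENABLE_SHARED_WEIGHTS=1
    "ipc_handles": {
        "model.embed_tokens.weight": "<serialized CUDA IPC handle>",
    },
}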
@@ -137,31 +178,11 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
Returns:
PEFT model with LoRA adapters applied
"""
if not PEFT_AVAILABLE:
if not PEFT_AVAILABLE:  # LoRA mode requires PEFT; fail fast if it is missing
raise RuntimeError("PEFT library not available. Install with: pip install peft")
print("[Setup] Loading base model for LoRA mode...")
flash_available = False
try:
import flash_attn
flash_available = True
except ImportError:
pass
if flash_available:
base_model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("[Setup] Using Flash Attention 2")
else:
base_model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
)
print("[Setup] Using SDPA attention")
base_model = _load_model_with_attention(config.model_name)
base_model.to(config.device)
# Determine target modules
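The rest of the function is truncated by the diff; for context, typical PEFT wiring looks like the sketch below (target modules and hyperparameters are illustrative, not this repo's values):

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,  # adapter rank (illustrative)
    lora_alpha=32,  # LoRA scaling factor (illustrative)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)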
@@ -221,6 +242,7 @@ def _attach_to_vllm_shared_tensors(
Model with parameters pointing to vLLM's tensors, or None if not possible
"""
print(f"[Setup] Reading bridge config from: {bridge_config_path}")
# Load the bridge config file we just located
try:
with open(bridge_config_path, "r") as f:
bridge_config = json.load(f)
@@ -231,12 +253,12 @@
single_copy_enabled = bridge_config.get("single_copy_enabled", False)
print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}")
# If single-copy is not enabled, bail out: vLLM was likely started without shared weights
if not single_copy_enabled:
print("[Setup] Single-copy mode not available (single_copy_enabled=False)")
print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1")
return None
# Get the IPC handles from the bridge config; these reference the regions of GPU memory where the shared weights live
ipc_handles_raw = bridge_config.get("ipc_handles", {})
print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}")
if not ipc_handles_raw:
@@ -250,42 +272,14 @@
print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
# Load model config (not weights) to get architecture
# It stores no weight buffers, only the architecture metadata:
# the blueprint for the house, not the house itself.
model_config = AutoConfig.from_pretrained(config.model_name)
# Create empty model on meta device (no memory allocation)
# Try Flash Attention 2 first (matches vLLM better), fall back to SDPA
with torch.device("meta"):
# Check if flash_attn is available before trying to use it
flash_available = False
try:
import flash_attn
flash_available = True
except ImportError:
pass
if flash_available:
try:
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("[Setup] Using Flash Attention 2 (best vLLM alignment)")
except Exception as e:
print(f"[Setup] Flash Attention 2 failed ({e}), using SDPA")
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
)
else:
print("[Setup] flash_attn not installed, using SDPA attention")
print("[Setup] NOTE: ~10-15% logprob diff with vLLM expected (different attention impl)")
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
)
model = _load_model_with_attention(model_config, from_config=True)
param_names = list(model.state_dict().keys())
print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
@@ -586,6 +580,7 @@ def _create_vllm_to_hf_mapping(
rather than calculating from config (which can be wrong for some models).
"""
hf_state_dict = model.state_dict()
print("Here is the HF state dict so that we can get a better view ")
hf_params = set(hf_state_dict.keys())
vllm_params = set(ipc_handles.keys())
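A quick way to see what needs remapping is to diff the two key sets; names that appear on only one side (for example, a fused vLLM projection that covers several separate HF tensors) are what the QKV logic below handles:

only_vllm = vllm_params - hf_params  # e.g. fused projections to be split
only_hf = hf_params - vllm_params  # HF tensors with no direct vLLM twin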
@@ -601,7 +596,7 @@
if head_dim is None:
head_dim = hidden_size // num_attention_heads
# Determine QKV sizes from ACTUAL HF model tensor shapes (more reliable)
# Determine QKV sizes from ACTUAL HF model tensor shapes
# Look for a q_proj weight in the model to get the actual size
q_size = None
k_size = None
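A minimal sketch of reading those sizes straight off the state dict (the key pattern is illustrative; the actual scan is truncated here):

for key, tensor in hf_state_dict.items():
    # nn.Linear weights are stored as [out_features, in_features].
    if key.endswith("q_proj.weight") and q_size is None:
        q_size = tensor.shape[0]
    elif key.endswith("k_proj.weight") and k_size is None:
        k_size = tensor.shape[0]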