mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
cleanup
This commit is contained in:
parent
24b8ab8574
commit
c8884348c7
8 changed files with 360 additions and 820 deletions
|
|
@ -25,6 +25,67 @@ except ImportError:
|
|||
PEFT_AVAILABLE = False
|
||||
|
||||
|
||||
def _get_attention_implementation() -> str:
|
||||
"""
|
||||
Determine the best attention implementation to use.
|
||||
Priority:
|
||||
1. Flash Attention 2 (if flash_attn library is available and works)
|
||||
2. SDPA (PyTorch's scaled dot-product attention)
|
||||
|
||||
Returns:
|
||||
Tuple of (attn_implementation string, human-readable name)
|
||||
"""
|
||||
try:
|
||||
import flash_attn
|
||||
return "flash_attention_2"
|
||||
except ImportError:
|
||||
return "sdpa"
|
||||
|
||||
|
||||
def _load_model_with_attention(
|
||||
model_name_or_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
from_config: bool = False,
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
Load a model with the best available attention implementation.
|
||||
|
||||
Args:
|
||||
model_name_or_config: Either a model name (str) or a model config object
|
||||
torch_dtype: Data type for model weights
|
||||
from_config: If True, use from_config (for meta device loading - no weights)
|
||||
If False, use from_pretrained (downloads and loads weights)
|
||||
|
||||
Returns:
|
||||
Loaded model with appropriate attention implementation
|
||||
"""
|
||||
# Select the loader function based on mode
|
||||
# from_config: creates empty shell (meta device), from_pretrained: loads weights
|
||||
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained
|
||||
|
||||
# Try attention implementations in order of preference
|
||||
for attn_impl in ["flash_attention_2", "sdpa"]:
|
||||
# Skip flash_attention_2 if not available
|
||||
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2":
|
||||
continue
|
||||
|
||||
try:
|
||||
model = loader(
|
||||
model_name_or_config,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
print(f"[Setup] Using {attn_impl.replace('_', ' ').title()}")
|
||||
return model
|
||||
except Exception as e:
|
||||
if attn_impl == "flash_attention_2":
|
||||
print(f"[Setup] Flash Attention 2 failed ({e}), trying SDPA...")
|
||||
continue
|
||||
raise
|
||||
|
||||
# Should never reach here, but just in case
|
||||
raise RuntimeError("Failed to load model with any attention implementation")
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
config: TrainingConfig,
|
||||
single_copy: bool = False,
|
||||
|
|
@ -67,27 +128,7 @@ def load_model_and_tokenizer(
|
|||
else:
|
||||
# Legacy mode: load full model
|
||||
print("[Setup] Loading model for legacy mode...")
|
||||
flash_available = False
|
||||
try:
|
||||
import flash_attn
|
||||
flash_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if flash_available:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
print("[Setup] Using Flash Attention 2")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
print("[Setup] Using SDPA attention")
|
||||
model = _load_model_with_attention(config.model_name)
|
||||
model.to(config.device)
|
||||
|
||||
# Enable gradient checkpointing
|
||||
|
|
@ -111,7 +152,7 @@ def _find_vllm_config(config: TrainingConfig) -> str:
|
|||
"/tmp/atropos_bridge",
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
]
|
||||
|
||||
# Look through possible
|
||||
for log_dir in possible_paths:
|
||||
candidate = os.path.join(log_dir, "vllm_bridge_config.json")
|
||||
if os.path.exists(candidate):
|
||||
|
|
@ -137,31 +178,11 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
|||
Returns:
|
||||
PEFT model with LoRA adapters applied
|
||||
"""
|
||||
if not PEFT_AVAILABLE:
|
||||
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
|
||||
raise RuntimeError("PEFT library not available. Install with: pip install peft")
|
||||
|
||||
print("[Setup] Loading base model for LoRA mode...")
|
||||
flash_available = False
|
||||
try:
|
||||
import flash_attn
|
||||
flash_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if flash_available:
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
print("[Setup] Using Flash Attention 2")
|
||||
else:
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
print("[Setup] Using SDPA attention")
|
||||
base_model = _load_model_with_attention(config.model_name)
|
||||
base_model.to(config.device)
|
||||
|
||||
# Determine target modules
|
||||
|
|
@ -221,6 +242,7 @@ def _attach_to_vllm_shared_tensors(
|
|||
Model with parameters pointing to vLLM's tensors, or None if not possible
|
||||
"""
|
||||
print(f"[Setup] Reading bridge config from: {bridge_config_path}")
|
||||
# Load the bridge that we just searched for
|
||||
try:
|
||||
with open(bridge_config_path, "r") as f:
|
||||
bridge_config = json.load(f)
|
||||
|
|
@ -231,12 +253,12 @@ def _attach_to_vllm_shared_tensors(
|
|||
|
||||
single_copy_enabled = bridge_config.get("single_copy_enabled", False)
|
||||
print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}")
|
||||
|
||||
# If single copy is not enable here then we exist because VLLM is likely botched
|
||||
if not single_copy_enabled:
|
||||
print("[Setup] Single-copy mode not available (single_copy_enabled=False)")
|
||||
print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1")
|
||||
return None
|
||||
|
||||
# Get the IPC handles. from the bridge config. these are memory pointers to the space in memory that shared weights exist in
|
||||
ipc_handles_raw = bridge_config.get("ipc_handles", {})
|
||||
print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}")
|
||||
if not ipc_handles_raw:
|
||||
|
|
@ -250,42 +272,14 @@ def _attach_to_vllm_shared_tensors(
|
|||
print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
|
||||
|
||||
# Load model config (not weights) to get architecture
|
||||
# doesn't store the buffers just basically the schematics. This is the
|
||||
# the blueprint for the house not the actual house
|
||||
model_config = AutoConfig.from_pretrained(config.model_name)
|
||||
|
||||
# Create empty model on meta device (no memory allocation)
|
||||
# Try Flash Attention 2 first (matches vLLM better), fall back to SDPA
|
||||
with torch.device("meta"):
|
||||
# Check if flash_attn is available before trying to use it
|
||||
flash_available = False
|
||||
try:
|
||||
import flash_attn
|
||||
flash_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if flash_available:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
print("[Setup] Using Flash Attention 2 (best vLLM alignment)")
|
||||
except Exception as e:
|
||||
print(f"[Setup] Flash Attention 2 failed ({e}), using SDPA")
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
else:
|
||||
print("[Setup] flash_attn not installed, using SDPA attention")
|
||||
print("[Setup] NOTE: ~10-15% logprob diff with vLLM expected (different attention impl)")
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model = _load_model_with_attention(model_config, from_config=True)
|
||||
|
||||
param_names = list(model.state_dict().keys())
|
||||
print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
|
||||
|
|
@ -586,6 +580,7 @@ def _create_vllm_to_hf_mapping(
|
|||
rather than calculating from config (which can be wrong for some models).
|
||||
"""
|
||||
hf_state_dict = model.state_dict()
|
||||
print("Here is the HF state dict so that we can get a better view ")
|
||||
hf_params = set(hf_state_dict.keys())
|
||||
vllm_params = set(ipc_handles.keys())
|
||||
|
||||
|
|
@ -601,7 +596,7 @@ def _create_vllm_to_hf_mapping(
|
|||
if head_dim is None:
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
|
||||
# Determine QKV sizes from ACTUAL HF model tensor shapes (more reliable)
|
||||
# Determine QKV sizes from ACTUAL HF model tensor shapes
|
||||
# Look for a q_proj weight in the model to get the actual size
|
||||
q_size = None
|
||||
k_size = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue