Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-22 16:48:57 +00:00
readme updates
commit 41cab9d52a
parent 998c062709
2 changed files with 47 additions and 15 deletions
@@ -65,9 +65,18 @@ def load_model_and_tokenizer(
     else:
         # Legacy mode: load full model
         print("[Setup] Loading model for legacy mode...")
-        model = AutoModelForCausalLM.from_pretrained(
-            config.model_name, torch_dtype=torch.bfloat16
-        )
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+            )
+            print("[Setup] Using Flash Attention 2")
+        except Exception as e:
+            print(f"[Setup] Flash Attention 2 not available: {e}")
+            model = AutoModelForCausalLM.from_pretrained(
+                config.model_name, torch_dtype=torch.bfloat16
+            )
         model.to(config.device)

         # Enable gradient checkpointing
@@ -121,9 +130,18 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
         raise RuntimeError("PEFT library not available. Install with: pip install peft")

     print("[Setup] Loading base model for LoRA mode...")
-    base_model = AutoModelForCausalLM.from_pretrained(
-        config.model_name, torch_dtype=torch.bfloat16
-    )
+    try:
+        base_model = AutoModelForCausalLM.from_pretrained(
+            config.model_name,
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+        )
+        print("[Setup] Using Flash Attention 2")
+    except Exception as e:
+        print(f"[Setup] Flash Attention 2 not available: {e}")
+        base_model = AutoModelForCausalLM.from_pretrained(
+            config.model_name, torch_dtype=torch.bfloat16
+        )
     base_model.to(config.device)

     # Determine target modules
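Both hunks above add the same fallback: attempt to load the model with `attn_implementation="flash_attention_2"` and, if the flash-attn package is not installed, reload with the default attention. A minimal sketch of how the duplicated try/except could be factored into a single helper; the name `load_causal_lm_with_flash_attn` and the standalone imports are illustrative assumptions, not code from this commit:

```python
import torch
from transformers import AutoModelForCausalLM


def load_causal_lm_with_flash_attn(model_name: str, device: str) -> torch.nn.Module:
    """Hypothetical helper: load a causal LM in bfloat16, preferring Flash Attention 2."""
    try:
        # Prefer Flash Attention 2 when the flash-attn package is installed.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        print("[Setup] Using Flash Attention 2")
    except Exception as e:
        # Fall back to the default attention implementation.
        print(f"[Setup] Flash Attention 2 not available: {e}")
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )
    return model.to(device)
```

Called as `model = load_causal_lm_with_flash_attn(config.model_name, config.device)`, such a helper would replace the repeated block in both the legacy and LoRA loading paths.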
@@ -215,11 +233,23 @@ def _attach_to_vllm_shared_tensors(
     model_config = AutoConfig.from_pretrained(config.model_name)

     # Create empty model on meta device (no memory allocation)
+    # Use Flash Attention 2 to match vLLM's attention implementation more closely
+    # This reduces logprob differences from ~10-15% to ~1-2%
     with torch.device("meta"):
-        model = AutoModelForCausalLM.from_config(
-            model_config,
-            torch_dtype=torch.bfloat16,
-        )
+        try:
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+            )
+            print("[Setup] Using Flash Attention 2 (matches vLLM)")
+        except Exception as e:
+            print(f"[Setup] Flash Attention 2 not available ({e}), using default attention")
+            print("[Setup] WARNING: This may cause ~10-15% logprob difference with vLLM")
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch.bfloat16,
+            )

     param_names = list(model.state_dict().keys())
     print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
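The third hunk builds the model under PyTorch's meta device, so only parameter names, shapes, and dtypes exist until the shared vLLM tensors are attached. A small self-contained sketch of that behavior, using a plain `nn.Linear` purely for illustration (not code from this commit):

```python
import torch
import torch.nn as nn

# Modules created under the meta device get parameters with shape and dtype
# but no backing storage, so nothing is actually allocated.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096, dtype=torch.bfloat16)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096])

# Parameter names are still enumerable, which is what the hunk relies on to
# map the architecture onto vLLM's shared tensors afterwards.
print(list(layer.state_dict().keys()))  # ['weight', 'bias']

# Bytes the weights *would* occupy once real storage is attached:
print(sum(p.element_size() * p.nelement() for p in layer.parameters()))
```

Because no storage is allocated, enumerating `model.state_dict().keys()` on the meta model, as the hunk does, stays cheap even for large checkpoints.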