readme updates

Jai Suphavadeeprasit 2026-01-27 17:13:18 -05:00
parent 998c062709
commit 41cab9d52a
2 changed files with 47 additions and 15 deletions

@@ -65,9 +65,18 @@ def load_model_and_tokenizer(
     else:
         # Legacy mode: load full model
         print("[Setup] Loading model for legacy mode...")
-        model = AutoModelForCausalLM.from_pretrained(
-            config.model_name, torch_dtype=torch.bfloat16
-        )
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+            )
+            print("[Setup] Using Flash Attention 2")
+        except Exception as e:
+            print(f"[Setup] Flash Attention 2 not available: {e}")
+            model = AutoModelForCausalLM.from_pretrained(
+                config.model_name, torch_dtype=torch.bfloat16
+            )
 
     model.to(config.device)
     # Enable gradient checkpointing
@@ -121,9 +130,18 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
         raise RuntimeError("PEFT library not available. Install with: pip install peft")
 
     print("[Setup] Loading base model for LoRA mode...")
-    base_model = AutoModelForCausalLM.from_pretrained(
-        config.model_name, torch_dtype=torch.bfloat16
-    )
+    try:
+        base_model = AutoModelForCausalLM.from_pretrained(
+            config.model_name,
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+        )
+        print("[Setup] Using Flash Attention 2")
+    except Exception as e:
+        print(f"[Setup] Flash Attention 2 not available: {e}")
+        base_model = AutoModelForCausalLM.from_pretrained(
+            config.model_name, torch_dtype=torch.bfloat16
+        )
 
     base_model.to(config.device)
     # Determine target modules
@@ -215,11 +233,23 @@ def _attach_to_vllm_shared_tensors(
     model_config = AutoConfig.from_pretrained(config.model_name)
 
     # Create empty model on meta device (no memory allocation)
+    # Use Flash Attention 2 to match vLLM's attention implementation more closely
+    # This reduces logprob differences from ~10-15% to ~1-2%
     with torch.device("meta"):
-        model = AutoModelForCausalLM.from_config(
-            model_config,
-            torch_dtype=torch.bfloat16,
-        )
+        try:
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+            )
+            print("[Setup] Using Flash Attention 2 (matches vLLM)")
+        except Exception as e:
+            print(f"[Setup] Flash Attention 2 not available ({e}), using default attention")
+            print("[Setup] WARNING: This may cause ~10-15% logprob difference with vLLM")
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch.bfloat16,
+            )
 
     param_names = list(model.state_dict().keys())
     print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)