Fix gradient checkpointing issue for LoRAs

Jai Suphavadeeprasit 2025-12-09 10:04:36 -05:00
parent a7bdc0270d
commit e202e2c288


@@ -389,7 +389,17 @@ def load_model_and_tokenizer(
     )
     model.to(config.device)
-    model.gradient_checkpointing_enable()
+    # Enable gradient checkpointing (saves memory)
+    # For LoRA, use PEFT's method; for others, use standard method
+    if config.weight_bridge_mode == "lora_only":
+        # PEFT models need gradient_checkpointing enabled on base model
+        # and require use_reentrant=False for proper gradient flow
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+    else:
+        model.gradient_checkpointing_enable()
     model.train()
     return model, tokenizer
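
For reference, the same pattern in isolation: a minimal sketch assuming the `transformers` and `peft` libraries, with a placeholder base model (`gpt2`) and made-up LoRA hyperparameters rather than anything taken from this repo's config.

```python
# Minimal standalone sketch of the pattern applied in this commit
# (not this repo's load_model_and_tokenizer); model name and LoRA
# hyperparameters below are illustrative placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(base, lora_cfg)

# With the base weights frozen, checkpointed segments see no inputs that
# require grad; this hooks the embeddings so gradients can reach the adapters.
model.enable_input_require_grads()

# Non-reentrant checkpointing is recommended for PEFT: the reentrant backward
# expects at least one grad-requiring input, which a mostly-frozen LoRA model
# can violate.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

model.train()
```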