gradient flow fix

Author: Jai Suphavadeeprasit
Date: 2026-02-12 13:10:19 -05:00
parent 1c8bb34bc1
commit 33844c374b


```diff
@@ -226,8 +226,9 @@ def _setup_gradient_checkpointing(
     # Disable KV cache - incompatible with gradient checkpointing
     model.config.use_cache = False
-    if config.weight_bridge_mode == "lora_only":
-        # PEFT models need special handling
+    if config.weight_bridge_mode in ("lora_only", "lora_restart"):
+        # PEFT models need special handling - enable_input_require_grads is CRITICAL
+        # Without this, the LoRA parameters won't receive gradients!
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
     model.gradient_checkpointing_enable(
```
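
For readers outside this repository, below is a minimal, self-contained sketch of the pattern this diff applies, written against the public transformers/peft APIs. The base model, LoRA settings, and the smoke test are illustrative assumptions, not code from this repo; in particular, `config.weight_bridge_mode` is this project's own switch and is omitted here.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative base model and adapter config (assumptions, not from this repo).
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM"),
)

# KV cache is incompatible with gradient checkpointing.
model.config.use_cache = False

# With reentrant gradient checkpointing (the long-time default), a checkpointed
# segment is only recomputed in the backward pass if its inputs require grad.
# A PEFT model freezes the base weights, so the embedding outputs do not require
# grad and autograd never reaches the LoRA adapters: their .grad stays None.
# enable_input_require_grads registers a hook that forces the input embeddings'
# outputs to require grad, restoring the path.
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

model.gradient_checkpointing_enable()

# Smoke test: after one backward pass, the LoRA parameters should have grads.
model.train()
ids = torch.tensor([[1, 2, 3, 4]])
model(input_ids=ids, labels=ids).loss.backward()
assert all(
    p.grad is not None
    for n, p in model.named_parameters()
    if "lora_" in n and p.requires_grad
)
```

Passing `gradient_checkpointing_kwargs={"use_reentrant": False}` to `gradient_checkpointing_enable` is another known way to avoid the broken graph, but the `hasattr`-guarded hook shown here works with either checkpointing variant, which is presumably why the diff takes this route.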