diff --git a/example_trainer/model.py b/example_trainer/model.py
index 97484eb7..c6df552c 100644
--- a/example_trainer/model.py
+++ b/example_trainer/model.py
@@ -226,8 +226,9 @@ def _setup_gradient_checkpointing(
     # Disable KV cache - incompatible with gradient checkpointing
    model.config.use_cache = False

-    if config.weight_bridge_mode == "lora_only":
-        # PEFT models need special handling
+    if config.weight_bridge_mode in ("lora_only", "lora_restart"):
+        # PEFT models need special handling - enable_input_require_grads is CRITICAL
+        # Without this, the LoRA parameters won't receive gradients!
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
     model.gradient_checkpointing_enable(
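
For context on why the new comment calls `enable_input_require_grads` critical, here is a minimal standalone sketch of the failure mode. The names `base`, `lora`, and `block` are illustrative, not code from this repo, and the sketch assumes reentrant activation checkpointing, which Transformers used by default for a long time:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Frozen "base" layer plus a trainable LoRA-style adapter, both used
# inside one checkpointed block (toy stand-ins for a transformer layer).
base = torch.nn.Linear(8, 8).requires_grad_(False)
lora = torch.nn.Linear(8, 8, bias=False)  # the only trainable params

def block(x):
    return base(x) + lora(x)

# Under LoRA the embedding weights are frozen, so the block's input does
# not require grad. Reentrant checkpointing then sees no grad-requiring
# inputs and detaches the output from autograd entirely.
x = torch.randn(2, 8)
out = checkpoint(block, x, use_reentrant=True)
print(out.requires_grad)  # False: backward would never reach lora.weight

# enable_input_require_grads() registers a forward hook on the input
# embeddings that marks their output with requires_grad=True. Simulated
# here by flipping the flag on the input tensor directly:
x.requires_grad_(True)
out = checkpoint(block, x, use_reentrant=True)
out.sum().backward()
print(lora.weight.grad is not None)  # True: the adapter trains again
```

The `hasattr` guard in the diff keeps the same fix working for plain (non-PEFT) models, which don't need the hook because their embedding weights are trainable and already put grad-requiring tensors into each checkpointed segment.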