diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 4464869f..2d40a445 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -389,7 +389,17 @@ def load_model_and_tokenizer(
     )
 
     model.to(config.device)
-    model.gradient_checkpointing_enable()
+    # Enable gradient checkpointing (saves memory).
+    # For LoRA, use PEFT's method; for others, use the standard method.
+    if config.weight_bridge_mode == "lora_only":
+        # PEFT models need gradient checkpointing enabled on the base model
+        # and require use_reentrant=False for proper gradient flow.
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+    else:
+        model.gradient_checkpointing_enable()
+
     model.train()
 
     return model, tokenizer
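
Context on the PEFT quirk this branch guards against: with a frozen base model, the embedding outputs entering the first checkpointed segment do not require grad, so reentrant checkpointing builds no backward graph through that segment and the LoRA adapters receive no gradients. Calling enable_input_require_grads() forces the input embeddings to require grad, and use_reentrant=False selects the non-reentrant checkpoint implementation, which tolerates inputs that do not require grad. Below is a minimal standalone sketch of the same pattern, assuming the Hugging Face transformers and peft libraries; the model name and LoRA hyperparameters are illustrative placeholders, not values from this trainer's config.

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Placeholder model and LoRA settings, chosen only for illustration.
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"]))

    # Force the input embeddings to require grad so checkpointed segments
    # still build a backward graph even though the base weights are frozen.
    model.enable_input_require_grads()

    # Non-reentrant checkpointing handles inputs that do not require grad;
    # the reentrant variant can silently drop the LoRA gradients.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.train()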