checkpointing fixes

This commit is contained in:
Jai Suphavadeeprasit 2026-01-29 11:41:24 -05:00
parent b9414e4076
commit 04652fd97c

View file

@@ -48,6 +48,8 @@ def load_model_and_tokenizer(
if model is not None:
print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!")
# Enable gradient checkpointing to save memory (was missing before!)
_setup_gradient_checkpointing(model, config)
model.train()
return model, tokenizer
else: