memory enhancements

This commit is contained in:
Jai Suphavadeeprasit 2026-01-29 21:44:24 -05:00
parent 99eaab3192
commit 75c4f5c853
4 changed files with 43 additions and 7 deletions

View file

@ -98,6 +98,12 @@ def save_checkpoint(
# Save state dict manually, then save config separately
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
model.config.save_pretrained(checkpoint_path)
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
del state_dict
import gc
gc.collect()
torch.cuda.empty_cache()
else:
# Standard save (may have issues with view tensors)
model.save_pretrained(checkpoint_path)