memory enhancements

This commit is contained in:
Jai Suphavadeeprasit 2026-01-29 21:44:24 -05:00
parent 99eaab3192
commit 75c4f5c853
4 changed files with 43 additions and 7 deletions

View file

@ -60,6 +60,16 @@ def parse_args() -> argparse.Namespace:
default=32,
help="Number of gradient accumulation steps",
)
parser.add_argument(
"--optimizer",
type=str,
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
default="adamw_8bit",
help="Optimizer: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
)
parser.add_argument(
"--device",
type=str,
@ -245,6 +255,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
batch_size=args.batch_size,
seq_len=args.seq_len,
gradient_accumulation_steps=args.gradient_accumulation_steps,
optimizer=args.optimizer,
device=args.device,
save_path=args.save_path,
vllm_restart_interval=args.vllm_restart_interval,