mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
memory enhancements
This commit is contained in:
parent
99eaab3192
commit
75c4f5c853
4 changed files with 43 additions and 7 deletions
|
|
@ -60,6 +60,16 @@ def parse_args() -> argparse.Namespace:
|
|||
default=32,
|
||||
help="Number of gradient accumulation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
||||
default="adamw_8bit",
|
||||
help="Optimizer: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
@ -245,6 +255,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
optimizer=args.optimizer,
|
||||
device=args.device,
|
||||
save_path=args.save_path,
|
||||
vllm_restart_interval=args.vllm_restart_interval,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue