mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
lora restart saving gradient changes
This commit is contained in:
parent
1127083b5f
commit
90281f5993
7 changed files with 805 additions and 19 deletions
|
|
@ -121,6 +121,12 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=9001,
|
||||
help="Port for the vLLM server",
|
||||
)
|
||||
group.add_argument(
|
||||
"--vllm-gpu",
|
||||
type=int,
|
||||
default=None,
|
||||
help="GPU ID for vLLM server. If not set, uses same GPU as trainer.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
"--vllm-gpu-memory-utilization",
|
||||
|
|
@ -146,7 +152,7 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
"--vllm-restart-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Restart vLLM every N training steps (legacy mode only)",
|
||||
help="Restart vLLM every N training steps (legacy/lora_restart modes)",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -189,9 +195,12 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
|
|||
group.add_argument(
|
||||
"--weight-bridge-mode",
|
||||
type=str,
|
||||
choices=["shared_vllm", "lora_only", "none"],
|
||||
choices=["shared_vllm", "lora_only", "lora_restart", "none"],
|
||||
default="none",
|
||||
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
|
||||
help=(
|
||||
"Weight sync mode: 'shared_vllm' (CUDA IPC), 'lora_only' (slow, --enforce-eager), "
|
||||
"'lora_restart' (fast, restarts vLLM), or 'none' (legacy)"
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--vllm-config-path",
|
||||
|
|
@ -348,6 +357,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
# vLLM settings
|
||||
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu=getattr(args, "vllm_gpu", None),
|
||||
vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
|
||||
max_model_len=getattr(args, "max_model_len", 4096),
|
||||
dtype=getattr(args, "dtype", "bfloat16"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue