lora restart saving gradient changes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-12 10:43:24 -05:00
parent 1127083b5f
commit 90281f5993
7 changed files with 805 additions and 19 deletions

View file

@@ -121,6 +121,12 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
default=9001,
help="Port for the vLLM server",
)
group.add_argument(
"--vllm-gpu",
type=int,
default=None,
help="GPU ID for vLLM server. If not set, uses same GPU as trainer.",
)
group.add_argument(
"--gpu-memory-utilization",
"--vllm-gpu-memory-utilization",
@@ -146,7 +152,7 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
"--vllm-restart-interval",
type=int,
default=3,
help="Restart vLLM every N training steps (legacy mode only)",
help="Restart vLLM every N training steps (legacy/lora_restart modes)",
)
@@ -189,9 +195,12 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
group.add_argument(
"--weight-bridge-mode",
type=str,
choices=["shared_vllm", "lora_only", "none"],
choices=["shared_vllm", "lora_only", "lora_restart", "none"],
default="none",
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
help=(
"Weight sync mode: 'shared_vllm' (CUDA IPC), 'lora_only' (slow, --enforce-eager), "
"'lora_restart' (fast, restarts vLLM), or 'none' (legacy)"
),
)
group.add_argument(
"--vllm-config-path",
@@ -348,6 +357,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
# vLLM settings
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
vllm_port=args.vllm_port,
vllm_gpu=getattr(args, "vllm_gpu", None),
vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
max_model_len=getattr(args, "max_model_len", 4096),
dtype=getattr(args, "dtype", "bfloat16"),