nccl loras

This commit is contained in:
Jai Suphavadeeprasit 2026-02-11 20:46:41 -05:00
parent 950be6f0d4
commit 2501e33ae3
8 changed files with 1121 additions and 16 deletions

View file

@ -189,9 +189,9 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
group.add_argument(
"--weight-bridge-mode",
type=str,
choices=["shared_vllm", "lora_only", "none"],
choices=["shared_vllm", "lora_only", "lora_nccl", "none"],
default="none",
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
help="Weight sync mode: 'shared_vllm', 'lora_only', 'lora_nccl', or 'none' (legacy)",
)
group.add_argument(
"--vllm-config-path",
@ -218,6 +218,35 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
)
def add_nccl_args(parser: argparse.ArgumentParser) -> None:
"""Add NCCL weight bridge arguments (for lora_nccl mode)."""
group = parser.add_argument_group("NCCL Weight Bridge (lora_nccl mode)")
group.add_argument(
"--nccl-init-method",
type=str,
default="tcp://localhost:29500",
help="NCCL process group init method (tcp://host:port)",
)
group.add_argument(
"--nccl-world-size",
type=int,
default=2,
help="Total processes in NCCL group (trainer + vLLM instances)",
)
group.add_argument(
"--nccl-sync-every-step",
action="store_true",
default=True,
help="Sync weights after every step (true on-policy)",
)
group.add_argument(
"--no-nccl-sync-every-step",
action="store_false",
dest="nccl_sync_every_step",
help="Sync weights only at vllm_restart_interval",
)
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
"""Add distributed training arguments."""
group = parser.add_argument_group("Distributed Training")
@ -279,6 +308,7 @@ def create_full_parser() -> argparse.ArgumentParser:
add_wandb_args(parser)
add_mode_args(parser)
add_lora_args(parser)
add_nccl_args(parser)
add_distributed_args(parser)
add_debug_args(parser)
@ -367,4 +397,8 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
debug_loading=getattr(args, "debug_loading", False),
benchmark=getattr(args, "benchmark", False),
atropos_url=getattr(args, "atropos_url", "http://localhost:8000"),
# NCCL settings (for lora_nccl mode)
nccl_init_method=getattr(args, "nccl_init_method", "tcp://localhost:29500"),
nccl_world_size=getattr(args, "nccl_world_size", 2),
nccl_sync_every_step=getattr(args, "nccl_sync_every_step", True),
)