mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
nccl loras
This commit is contained in:
parent
950be6f0d4
commit
2501e33ae3
8 changed files with 1121 additions and 16 deletions
|
|
@ -105,12 +105,13 @@ class TrainingConfig(BaseModel):
|
|||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
# === Training Mode Configuration ===
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_nccl", "none"] = Field(
|
||||
"none",
|
||||
description=(
|
||||
"How to synchronize weights with inference server. "
|
||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. "
|
||||
"'lora_nccl': LoRA training with NCCL direct weight transfer (torchtitan-style). "
|
||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
|
|
@ -148,6 +149,30 @@ class TrainingConfig(BaseModel):
|
|||
),
|
||||
)
|
||||
|
||||
# === NCCL Weight Bridge Configuration (for lora_nccl mode) ===
|
||||
nccl_init_method: str = Field(
|
||||
"tcp://localhost:29500",
|
||||
description=(
|
||||
"NCCL process group init method for lora_nccl mode. "
|
||||
"Format: tcp://host:port"
|
||||
),
|
||||
)
|
||||
nccl_world_size: int = Field(
|
||||
2,
|
||||
description=(
|
||||
"Total number of processes in the NCCL weight bridge group. "
|
||||
"Typically 2: trainer (rank 0) + vLLM server (rank 1). "
|
||||
"For multi-GPU vLLM, this would be 1 + num_vllm_gpus."
|
||||
),
|
||||
)
|
||||
nccl_sync_every_step: bool = Field(
|
||||
True,
|
||||
description=(
|
||||
"Whether to sync weights after every training step (true on-policy). "
|
||||
"If False, syncs every vllm_restart_interval steps."
|
||||
),
|
||||
)
|
||||
|
||||
# === Single-Copy Mode Configuration ===
|
||||
single_copy: bool = Field(
|
||||
False,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue