mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
lora restart saving gradient changes
This commit is contained in:
parent
1127083b5f
commit
90281f5993
7 changed files with 805 additions and 19 deletions
|
|
@ -87,6 +87,13 @@ class TrainingConfig(BaseModel):
|
|||
3, description="Restart vLLM every N training steps (legacy mode)"
|
||||
)
|
||||
vllm_port: int = Field(9001, description="Port for the vLLM server")
|
||||
vllm_gpu: Optional[int] = Field(
|
||||
None,
|
||||
description=(
|
||||
"GPU ID for vLLM server (lora_restart/legacy modes). "
|
||||
"If None, uses same GPU as trainer. Set different for parallelism."
|
||||
),
|
||||
)
|
||||
vllm_gpu_memory_utilization: float = Field(
|
||||
0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
|
||||
)
|
||||
|
|
@ -105,12 +112,13 @@ class TrainingConfig(BaseModel):
|
|||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
# === Training Mode Configuration ===
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
|
||||
"none",
|
||||
description=(
|
||||
"How to synchronize weights with inference server. "
|
||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
||||
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue