lora restart saving gradient changes

Jai Suphavadeeprasit 2026-02-12 10:43:24 -05:00
parent 1127083b5f
commit 90281f5993
7 changed files with 805 additions and 19 deletions


@@ -87,6 +87,13 @@ class TrainingConfig(BaseModel):
         3, description="Restart vLLM every N training steps (legacy mode)"
     )
     vllm_port: int = Field(9001, description="Port for the vLLM server")
+    vllm_gpu: Optional[int] = Field(
+        None,
+        description=(
+            "GPU ID for vLLM server (lora_restart/legacy modes). "
+            "If None, uses same GPU as trainer. Set different for parallelism."
+        ),
+    )
     vllm_gpu_memory_utilization: float = Field(
         0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
     )
@@ -105,12 +112,13 @@ class TrainingConfig(BaseModel):
     wandb_group: Optional[str] = Field(None, description="Wandb group name")
     # === Training Mode Configuration ===
-    weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
+    weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
         "none",
         description=(
             "How to synchronize weights with inference server. "
             "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
-            "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. "
+            "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
+            "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
             "'none': legacy mode, restart vLLM with new checkpoint files."
         ),
     )
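And a rough sketch (again, not from this repository) of what the 'lora_restart' mode implies: keep the base model frozen, train LoRA adapters, periodically save the adapter, and relaunch vLLM pointing at it, which keeps CUDA graphs enabled instead of requiring --enforce-eager as the HTTP hot-swap path does. The model name, adapter path, and exact vLLM CLI flags are assumptions and should be checked against the installed vLLM version.

import subprocess
from pathlib import Path

BASE_MODEL = "your-org/base-model"  # placeholder, not from this diff


def restart_vllm_with_adapter(adapter_dir: Path, port: int, gpu_util: float) -> subprocess.Popen:
    # Relaunch the OpenAI-compatible vLLM server with the freshly saved LoRA
    # adapter. A full restart avoids the eager-mode requirement of live adapter
    # swapping, so CUDA graphs stay enabled for faster generation.
    return subprocess.Popen(
        [
            "vllm", "serve", BASE_MODEL,
            "--port", str(port),
            "--gpu-memory-utilization", str(gpu_util),
            "--enable-lora",
            "--lora-modules", f"trained={adapter_dir}",
        ]
    )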