Feedback fixes: shared layers, hard-coded values, and warmup steps

This commit is contained in:
Jai Suphavadeeprasit 2026-02-24 10:28:44 -05:00
parent e1f9b926bb
commit 624b3cdabe
9 changed files with 247 additions and 58 deletions

View file

@ -32,21 +32,41 @@ class TrainingConfig(BaseModel):
gradient_accumulation_steps: int = Field(
32, description="Number of gradient accumulation steps"
)
warmup_steps: int = Field(
0,
description=(
"Number of initial optimizer steps for linear LR warmup. "
"0 disables warmup."
),
)
optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
"adamw_8bit",
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adafactor' (no momentum, ~8GB GPU)",
description="Optimizer to use: 'adamw' (full precision), "
"'adamw_8bit' (8-bit states, requires bitsandbytes), "
"'adafactor' (Adafactor optimizer)",
)
adafactor_scale_parameter: bool = Field(
False,
description=(
"Whether to enable Adafactor scale_parameter behavior when using "
"optimizer='adafactor'."
),
)
adafactor_relative_step: bool = Field(
False,
description=(
"Whether to enable Adafactor relative_step behavior when using "
"optimizer='adafactor'."
),
)
# === GRPO/PPO Hyperparameters ===
kl_coef: float = Field(
0.1,
0.0,
description=(
"KL divergence penalty coefficient (beta). "
"Controls how much the policy can deviate from the reference (inference-time) policy. "
"Higher values = more conservative updates, prevents reward hacking. "
"Set to 0 to disable KL penalty (not recommended)."
"Coefficient for sampled-token KL-like regularization against rollout/reference "
"logprobs. Higher values make updates more conservative. "
"Set to 0 to disable this term."
),
)
clip_eps: float = Field(
@ -121,6 +141,13 @@ class TrainingConfig(BaseModel):
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
train_layer_indices: Optional[List[int]] = Field(
None,
description=(
"Optional list of transformer layer indices to train in shared/legacy "
"full-model modes. If None, all layers are trainable."
),
)
# === Distributed Training Configuration ===
trainer_rank: int = Field(