mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
feedback fixes: shared layers + hard coded values + warmup steps
This commit is contained in:
parent
e1f9b926bb
commit
624b3cdabe
9 changed files with 247 additions and 58 deletions
|
|
@@ -32,21 +32,41 @@ class TrainingConfig(BaseModel):
|
|||
gradient_accumulation_steps: int = Field(
|
||||
32, description="Number of gradient accumulation steps"
|
||||
)
|
||||
warmup_steps: int = Field(
|
||||
0,
|
||||
description=(
|
||||
"Number of initial optimizer steps for linear LR warmup. "
|
||||
"0 disables warmup."
|
||||
),
|
||||
)
|
||||
optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
|
||||
"adamw_8bit",
|
||||
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
description="Optimizer to use: 'adamw' (full precision), "
|
||||
"'adamw_8bit' (8-bit states, requires bitsandbytes), "
|
||||
"'adafactor' (Adafactor optimizer)",
|
||||
)
|
||||
adafactor_scale_parameter: bool = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to enable Adafactor scale_parameter behavior when using "
|
||||
"optimizer='adafactor'."
|
||||
),
|
||||
)
|
||||
adafactor_relative_step: bool = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to enable Adafactor relative_step behavior when using "
|
||||
"optimizer='adafactor'."
|
||||
),
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
kl_coef: float = Field(
|
||||
0.1,
|
||||
0.0,
|
||||
description=(
|
||||
"KL divergence penalty coefficient (beta). "
|
||||
"Controls how much the policy can deviate from the reference (inference-time) policy. "
|
||||
"Higher values = more conservative updates, prevents reward hacking. "
|
||||
"Set to 0 to disable KL penalty (not recommended)."
|
||||
"Coefficient for sampled-token KL-like regularization against rollout/reference "
|
||||
"logprobs. Higher values make updates more conservative. "
|
||||
"Set to 0 to disable this term."
|
||||
),
|
||||
)
|
||||
clip_eps: float = Field(
|
||||
|
|
@@ -121,6 +141,13 @@ class TrainingConfig(BaseModel):
|
|||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
train_layer_indices: Optional[List[int]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional list of transformer layer indices to train in shared/legacy "
|
||||
"full-model modes. If None, all layers are trainable."
|
||||
),
|
||||
)
|
||||
|
||||
# === Distributed Training Configuration ===
|
||||
trainer_rank: int = Field(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue