mirror of https://github.com/NousResearch/atropos.git
manual testing

commit c1bb4f33f0
parent da046d3d3b

5 changed files with 329 additions and 766 deletions
@@ -40,6 +40,33 @@ class TrainingConfig(BaseModel):
             "'adafactor' (no momentum, ~8GB GPU)"
         )
     )
+
+    # === GRPO/PPO Hyperparameters ===
+    kl_coef: float = Field(
+        0.1,
+        description=(
+            "KL divergence penalty coefficient (beta). "
+            "Controls how much the policy can deviate from the reference (inference-time) policy. "
+            "Higher values = more conservative updates, prevents reward hacking. "
+            "Set to 0 to disable KL penalty (not recommended)."
+        ),
+    )
+    clip_eps: float = Field(
+        0.2,
+        description=(
+            "PPO-style clipping epsilon. "
+            "Clips the importance sampling ratio to [1-eps, 1+eps]. "
+            "Prevents large policy updates that could destabilize training."
+        ),
+    )
+    use_reference_logprobs: bool = Field(
+        True,
+        description=(
+            "Whether to use inference logprobs as the reference policy (π_old). "
+            "When True, implements proper GRPO with importance sampling. "
+            "When False, falls back to REINFORCE-style updates (not recommended)."
+        ),
+    )
+
+    # === Device & Storage ===
+    device: str = Field(
+        "cuda" if torch.cuda.is_available() else "cpu",
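For context on how the three new fields interact, below is a minimal sketch of a GRPO/PPO-style loss written against this config. It is illustrative only, not this commit's trainer code: the grpo_loss helper and its tensor arguments (logprobs, ref_logprobs, advantages) are hypothetical names, and the KL term uses the common k3 estimator, which may differ from what the repo actually implements.

import torch


def grpo_loss(
    logprobs: torch.Tensor,      # log pi_theta per token, from the training policy
    ref_logprobs: torch.Tensor,  # log pi_old per token, from the inference engine
    advantages: torch.Tensor,    # group-normalized advantages (GRPO)
    kl_coef: float = 0.1,
    clip_eps: float = 0.2,
    use_reference_logprobs: bool = True,
) -> torch.Tensor:
    """Hypothetical GRPO/PPO loss showing how the config fields are used."""
    if use_reference_logprobs:
        # Importance sampling ratio r = pi_theta / pi_old, with the PPO-style
        # clip to [1 - eps, 1 + eps] that clip_eps controls.
        ratio = torch.exp(logprobs - ref_logprobs.detach())
        clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
        # k3-style estimate of KL(pi_theta || pi_old); kl_coef = 0 turns the
        # penalty off, matching the field's description above.
        log_ratio = ref_logprobs.detach() - logprobs
        kl = (torch.exp(log_ratio) - log_ratio - 1.0).mean()
        return policy_loss + kl_coef * kl
    # With use_reference_logprobs=False there is no ratio, no clipping, and
    # no KL anchor: a plain REINFORCE-style update.
    return -(logprobs * advantages).mean()

Wiring this to the config would then be a one-liner such as grpo_loss(lp, ref_lp, adv, cfg.kl_coef, cfg.clip_eps, cfg.use_reference_logprobs).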