manual testing

This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 15:40:24 -05:00
parent da046d3d3b
commit c1bb4f33f0
5 changed files with 329 additions and 766 deletions

View file

@ -40,6 +40,33 @@ class TrainingConfig(BaseModel):
"'adafactor' (no momentum, ~8GB GPU)"
)
# === GRPO/PPO Hyperparameters ===
# Pydantic fields on TrainingConfig (class header is outside this hunk).
# Each Field's first positional argument is the default value; descriptions
# are surfaced to users via the model's schema/help output.
# beta in the GRPO/PPO objective: weight of the KL(policy || reference)
# penalty term. Default 0.1; 0 disables the penalty entirely per the
# description below. NOTE(review): the actual penalty computation lives in
# the training loop, not shown here — confirm it reads this field.
kl_coef: float = Field(
    0.1,
    description=(
        "KL divergence penalty coefficient (beta). "
        "Controls how much the policy can deviate from the reference (inference-time) policy. "
        "Higher values = more conservative updates, prevents reward hacking. "
        "Set to 0 to disable KL penalty (not recommended)."
    ),
)
# epsilon for PPO-style ratio clipping; the importance-sampling ratio is
# clamped to [1 - clip_eps, 1 + clip_eps]. Default 0.2 (the standard PPO
# value from Schulman et al. 2017).
clip_eps: float = Field(
    0.2,
    description=(
        "PPO-style clipping epsilon. "
        "Clips the importance sampling ratio to [1-eps, 1+eps]. "
        "Prevents large policy updates that could destabilize training."
    ),
)
# Toggles between proper GRPO (importance sampling against logprobs
# recorded at inference time) and a REINFORCE-style fallback. Defaults to
# True; when False, the description warns the fallback is not recommended.
use_reference_logprobs: bool = Field(
    True,
    description=(
        "Whether to use inference logprobs as the reference policy (π_old). "
        "When True, implements proper GRPO with importance sampling. "
        "When False, falls back to REINFORCE-style updates (not recommended)."
    ),
)
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu",