This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 15:58:47 -05:00
parent 6a659a8c9d
commit 2b5debe0a2

View file

@ -100,6 +100,20 @@ def main():
parser.add_argument("--save-path", type=str, default="trained_model_checkpoints")
parser.add_argument("--checkpoint-interval", type=int, default=3)
# === GRPO/PPO Hyperparameters ===
parser.add_argument(
"--kl-coef", type=float, default=0.1,
help="KL divergence penalty coefficient. Higher = more conservative updates (default: 0.1)",
)
parser.add_argument(
"--clip-eps", type=float, default=0.2,
help="PPO clipping epsilon. Clips ratio to [1-eps, 1+eps] (default: 0.2)",
)
parser.add_argument(
"--no-reference-logprobs", action="store_true",
help="Disable use of inference logprobs as reference policy (not recommended)",
)
# === vLLM Server ===
parser.add_argument("--vllm-port", type=int, default=9001, help="Port for vLLM server")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.5, help="vLLM GPU memory fraction")
@ -233,6 +247,11 @@ def main():
use_wandb=args.use_wandb,
wandb_project=args.wandb_project,
checkpoint_interval=args.checkpoint_interval,
# GRPO hyperparameters
kl_coef=args.kl_coef,
clip_eps=args.clip_eps,
use_reference_logprobs=not args.no_reference_logprobs,
benchmark=True, # Always show timing info
)
try: