diff --git a/example_trainer/run.py b/example_trainer/run.py
index 9a28af13..1b089abb 100644
--- a/example_trainer/run.py
+++ b/example_trainer/run.py
@@ -100,6 +100,20 @@ def main():
     parser.add_argument("--save-path", type=str, default="trained_model_checkpoints")
     parser.add_argument("--checkpoint-interval", type=int, default=3)
 
+    # === GRPO/PPO Hyperparameters ===
+    parser.add_argument(
+        "--kl-coef", type=float, default=0.1,
+        help="KL divergence penalty coefficient. Higher = more conservative updates (default: 0.1)",
+    )
+    parser.add_argument(
+        "--clip-eps", type=float, default=0.2,
+        help="PPO clipping epsilon. Clips ratio to [1-eps, 1+eps] (default: 0.2)",
+    )
+    parser.add_argument(
+        "--no-reference-logprobs", action="store_true",
+        help="Disable use of inference logprobs as reference policy (not recommended)",
+    )
+
     # === vLLM Server ===
     parser.add_argument("--vllm-port", type=int, default=9001, help="Port for vLLM server")
     parser.add_argument("--gpu-memory-utilization", type=float, default=0.5, help="vLLM GPU memory fraction")
@@ -233,6 +247,11 @@ def main():
         use_wandb=args.use_wandb,
         wandb_project=args.wandb_project,
         checkpoint_interval=args.checkpoint_interval,
+        # GRPO hyperparameters
+        kl_coef=args.kl_coef,
+        clip_eps=args.clip_eps,
+        use_reference_logprobs=not args.no_reference_logprobs,
+        benchmark=True,  # Always show timing info
     )
 
     try: