manual testing

This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 15:40:24 -05:00
parent da046d3d3b
commit c1bb4f33f0
5 changed files with 329 additions and 766 deletions

View file

@@ -70,6 +70,38 @@ def parse_args() -> argparse.Namespace:
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
)
# === GRPO/PPO Hyperparameters ===
parser.add_argument(
"--kl-coef",
type=float,
default=0.1,
help=(
"KL divergence penalty coefficient (beta). "
"Controls policy deviation from reference. "
"Higher = more conservative, prevents reward hacking. "
"0 = disabled (not recommended)."
),
)
parser.add_argument(
"--clip-eps",
type=float,
default=0.2,
help=(
"PPO-style clipping epsilon. "
"Clips importance ratio to [1-eps, 1+eps]. "
"Prevents destabilizing large policy updates."
),
)
parser.add_argument(
"--no-reference-logprobs",
action="store_true",
help=(
"Disable use of inference logprobs as reference policy. "
"Falls back to REINFORCE-style updates (not recommended)."
),
)
parser.add_argument(
"--device",
type=str,
@@ -265,6 +297,11 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
device=args.device,
save_path=args.save_path,
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
kl_coef=getattr(args, "kl_coef", 0.1),
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
# vLLM settings
vllm_restart_interval=args.vllm_restart_interval,
vllm_port=args.vllm_port,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,