change on-policy distillation (OPD) style

This commit is contained in:
Jai Suphavadeeprasit 2026-02-19 17:08:27 -05:00
parent 33f5696171
commit 527433b5bc
10 changed files with 452 additions and 90 deletions

View file

@ -158,6 +158,30 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
action="store_true",
help="Disable use of inference logprobs as reference policy (not recommended).",
)
group.add_argument(
"--distillation-enabled",
action="store_true",
help="Enable on-policy distillation using teacher top-K arrays from Atropos.",
)
group.add_argument(
"--distillation-coef",
type=float,
default=0.1,
help="Scale factor for distillation loss contribution.",
)
group.add_argument(
"--distillation-temperature",
type=float,
default=1.0,
help="Temperature used for teacher/student matching in distillation loss.",
)
group.add_argument(
"--distillation-loss-type",
type=str,
choices=["kl", "cross_entropy"],
default="kl",
help="Distillation objective type.",
)
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
@ -411,6 +435,10 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
kl_coef=getattr(args, "kl_coef", 0.1),
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
distillation_enabled=getattr(args, "distillation_enabled", False),
distillation_coef=getattr(args, "distillation_coef", 0.1),
distillation_temperature=getattr(args, "distillation_temperature", 1.0),
distillation_loss_type=getattr(args, "distillation_loss_type", "kl"),
# vLLM settings
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
vllm_port=args.vllm_port,