testing set up

This commit is contained in:
Jai Suphavadeeprasit 2026-03-06 14:49:32 -05:00
parent f44eb810bf
commit 530fed2877
8 changed files with 599 additions and 2 deletions

View file

@ -163,6 +163,23 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
default=0.2,
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
)
group.add_argument(
"--distill-enabled",
action="store_true",
help="Enable teacher distillation loss (requires distill payload in Atropos batch).",
)
group.add_argument(
"--distill-coef",
type=float,
default=0.0,
help="Coefficient for distillation loss term.",
)
group.add_argument(
"--distill-temperature",
type=float,
default=1.0,
help="Temperature for teacher top-k distribution in distillation loss.",
)
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
@ -424,6 +441,9 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
clip_eps=getattr(args, "clip_eps", 0.2),
distill_enabled=getattr(args, "distill_enabled", False),
distill_coef=getattr(args, "distill_coef", 0.0),
distill_temperature=getattr(args, "distill_temperature", 1.0),
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
# vLLM settings