mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
testing set up
This commit is contained in:
parent
f44eb810bf
commit
530fed2877
8 changed files with 599 additions and 2 deletions
|
|
@ -163,6 +163,23 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=0.2,
|
||||
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-enabled",
|
||||
action="store_true",
|
||||
help="Enable teacher distillation loss (requires distill payload in Atropos batch).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-coef",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Coefficient for distillation loss term.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Temperature for teacher top-k distribution in distillation loss.",
|
||||
)
|
||||
|
||||
|
||||
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -424,6 +441,9 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
|
||||
# GRPO/PPO hyperparameters
|
||||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
distill_enabled=getattr(args, "distill_enabled", False),
|
||||
distill_coef=getattr(args, "distill_coef", 0.0),
|
||||
distill_temperature=getattr(args, "distill_temperature", 1.0),
|
||||
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
|
||||
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
|
||||
# vLLM settings
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue