mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-03 17:53:17 +00:00
feedback fixes: shared layers + hard coded values + warmup steps
This commit is contained in:
parent
e1f9b926bb
commit
624b3cdabe
9 changed files with 247 additions and 58 deletions
|
|
@ -17,7 +17,7 @@ from .config import TrainingConfig
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
|
||||
def _parse_layer_indices(value: str) -> Optional[List[int]]:
|
||||
"""
|
||||
Parse LoRA layer indices from comma/range syntax.
|
||||
|
||||
|
|
@ -110,6 +110,12 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=32,
|
||||
help="Number of gradient accumulation steps",
|
||||
)
|
||||
group.add_argument(
|
||||
"--warmup-steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Linear LR warmup steps (0 disables warmup).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
|
|
@ -118,6 +124,16 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||
"'adafactor' (no momentum)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--adafactor-scale-parameter",
|
||||
action="store_true",
|
||||
help="Enable Adafactor scale_parameter behavior (only used when --optimizer adafactor).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--adafactor-relative-step",
|
||||
action="store_true",
|
||||
help="Enable Adafactor relative_step behavior (only used when --optimizer adafactor).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
@ -144,8 +160,8 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
|
|||
group.add_argument(
|
||||
"--kl-coef",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
|
||||
default=0.0,
|
||||
help="Sampled-token KL-like regularization coefficient. Higher = more conservative.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--clip-eps",
|
||||
|
|
@ -256,6 +272,15 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=None,
|
||||
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--train-layer-indices",
|
||||
type=_parse_layer_indices,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional transformer layer indices to train in full/shared modes, e.g. "
|
||||
"'20-31' or '0-3,28-31'. If omitted, all layers are trainable."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -275,7 +300,7 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
|||
)
|
||||
group.add_argument(
|
||||
"--lora-layer-indices",
|
||||
type=_parse_lora_layer_indices,
|
||||
type=_parse_layer_indices,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
|
||||
|
|
@ -403,14 +428,17 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
warmup_steps=getattr(args, "warmup_steps", 0),
|
||||
optimizer=args.optimizer,
|
||||
device=args.device,
|
||||
save_path=args.save_path,
|
||||
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
|
||||
# GRPO/PPO hyperparameters
|
||||
kl_coef=getattr(args, "kl_coef", 0.1),
|
||||
kl_coef=getattr(args, "kl_coef", 0.0),
|
||||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
|
||||
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
|
||||
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
|
||||
# vLLM settings
|
||||
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
|
||||
vllm_port=args.vllm_port,
|
||||
|
|
@ -422,6 +450,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
wandb_project=args.wandb_project,
|
||||
wandb_group=getattr(args, "wandb_group", None),
|
||||
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
|
||||
train_layer_indices=getattr(args, "train_layer_indices", None),
|
||||
trainer_rank=getattr(args, "trainer_rank", 0),
|
||||
world_size=getattr(args, "world_size", 1),
|
||||
init_method=getattr(args, "init_method", "env://"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue