feedback fixes: shared layers + hard-coded values + warmup steps

This commit is contained in:
Jai Suphavadeeprasit 2026-02-24 10:28:44 -05:00
parent e1f9b926bb
commit 624b3cdabe
9 changed files with 247 additions and 58 deletions

View file

@@ -17,7 +17,7 @@ from .config import TrainingConfig
# =============================================================================
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
def _parse_layer_indices(value: str) -> Optional[List[int]]:
"""
Parse LoRA layer indices from comma/range syntax.
@@ -110,6 +110,12 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
default=32,
help="Number of gradient accumulation steps",
)
group.add_argument(
"--warmup-steps",
type=int,
default=0,
help="Linear LR warmup steps (0 disables warmup).",
)
group.add_argument(
"--optimizer",
type=str,
@@ -118,6 +124,16 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adafactor' (no momentum)",
)
group.add_argument(
"--adafactor-scale-parameter",
action="store_true",
help="Enable Adafactor scale_parameter behavior (only used when --optimizer adafactor).",
)
group.add_argument(
"--adafactor-relative-step",
action="store_true",
help="Enable Adafactor relative_step behavior (only used when --optimizer adafactor).",
)
group.add_argument(
"--device",
type=str,
@@ -144,8 +160,8 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
group.add_argument(
"--kl-coef",
type=float,
default=0.1,
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
default=0.0,
help="Sampled-token KL-like regularization coefficient. Higher = more conservative.",
)
group.add_argument(
"--clip-eps",
@@ -256,6 +272,15 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
default=None,
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
)
group.add_argument(
"--train-layer-indices",
type=_parse_layer_indices,
default=None,
help=(
"Optional transformer layer indices to train in full/shared modes, e.g. "
"'20-31' or '0-3,28-31'. If omitted, all layers are trainable."
),
)
def add_lora_args(parser: argparse.ArgumentParser) -> None:
@@ -275,7 +300,7 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
)
group.add_argument(
"--lora-layer-indices",
type=_parse_lora_layer_indices,
type=_parse_layer_indices,
default=None,
help=(
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
@@ -403,14 +428,17 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
batch_size=args.batch_size,
seq_len=args.seq_len,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_steps=getattr(args, "warmup_steps", 0),
optimizer=args.optimizer,
device=args.device,
save_path=args.save_path,
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
kl_coef=getattr(args, "kl_coef", 0.1),
kl_coef=getattr(args, "kl_coef", 0.0),
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
# vLLM settings
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
vllm_port=args.vllm_port,
@@ -422,6 +450,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
wandb_project=args.wandb_project,
wandb_group=getattr(args, "wandb_group", None),
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
train_layer_indices=getattr(args, "train_layer_indices", None),
trainer_rank=getattr(args, "trainer_rank", 0),
world_size=getattr(args, "world_size", 1),
init_method=getattr(args, "init_method", "env://"),