[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@@ -11,16 +11,17 @@ import torch
from .config import TrainingConfig
# =============================================================================
# Argument Group Builders (modular, reusable)
# =============================================================================
def add_model_args(parser: argparse.ArgumentParser) -> None:
"""Add model-related arguments."""
group = parser.add_argument_group("Model")
group.add_argument(
"--model", "--model-name",
"--model",
"--model-name",
type=str,
required=True,
dest="model_name",
@@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
default="adamw_8bit",
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
)
group.add_argument(
"--device",
@@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
help="Port for the vLLM server",
)
group.add_argument(
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
"--gpu-memory-utilization",
"--vllm-gpu-memory-utilization",
type=float,
default=0.45,
dest="gpu_memory_utilization",
@@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
"""Add LoRA-specific arguments."""
group = parser.add_argument_group("LoRA Configuration")
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
group.add_argument(
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
)
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
group.add_argument(
"--lora-target-modules",
@@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("Distributed Training")
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
group.add_argument("--world-size", type=int, default=1, help="World size")
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
group.add_argument(
"--init-method", type=str, default="env://", help="Distributed init method"
)
group.add_argument(
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
)
def add_debug_args(parser: argparse.ArgumentParser) -> None:
@@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
# Parser Builders
# =============================================================================
def create_base_parser(description: str) -> argparse.ArgumentParser:
"""Create a base parser with common formatting."""
return argparse.ArgumentParser(
@@ -261,7 +270,7 @@ def create_full_parser() -> argparse.ArgumentParser:
Create a parser with ALL arguments (for grpo.py multi-mode entry point).
"""
parser = create_base_parser("GRPO Trainer - Multi-mode training")
add_model_args(parser)
add_training_args(parser)
add_grpo_args(parser)
@@ -272,7 +281,7 @@ def create_full_parser() -> argparse.ArgumentParser:
add_lora_args(parser)
add_distributed_args(parser)
add_debug_args(parser)
return parser
@@ -283,7 +292,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
parser = create_base_parser(
"Unified GRPO Trainer - Starts vLLM server and trainer in one command"
)
add_model_args(parser)
add_training_args(parser)
add_grpo_args(parser)
@@ -291,7 +300,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
add_atropos_args(parser)
add_wandb_args(parser)
add_debug_args(parser)
return parser
@@ -299,10 +308,11 @@ def create_unified_parser() -> argparse.ArgumentParser:
# Legacy API (backwards compatibility)
# =============================================================================
def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments for the GRPO trainer (grpo.py).
Returns:
Parsed arguments namespace
"""
@@ -313,10 +323,10 @@ def parse_args() -> argparse.Namespace:
def config_from_args(args: argparse.Namespace) -> TrainingConfig:
"""
Build a TrainingConfig from parsed CLI arguments.
Args:
args: Parsed argparse namespace
Returns:
TrainingConfig instance
"""