mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -11,16 +11,17 @@ import torch
|
|||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Argument Group Builders (modular, reusable)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add model-related arguments."""
|
||||
group = parser.add_argument_group("Model")
|
||||
group.add_argument(
|
||||
"--model", "--model-name",
|
||||
"--model",
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="model_name",
|
||||
|
|
@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
||||
default="adamw_8bit",
|
||||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--device",
|
||||
|
|
@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
help="Port for the vLLM server",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
|
||||
"--gpu-memory-utilization",
|
||||
"--vllm-gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.45,
|
||||
dest="gpu_memory_utilization",
|
||||
|
|
@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
|||
"""Add LoRA-specific arguments."""
|
||||
group = parser.add_argument_group("LoRA Configuration")
|
||||
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
||||
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
|
||||
group.add_argument(
|
||||
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
|
||||
)
|
||||
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
|
||||
group.add_argument(
|
||||
"--lora-target-modules",
|
||||
|
|
@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
|||
group = parser.add_argument_group("Distributed Training")
|
||||
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
|
||||
group.add_argument("--world-size", type=int, default=1, help="World size")
|
||||
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
|
||||
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
|
||||
group.add_argument(
|
||||
"--init-method", type=str, default="env://", help="Distributed init method"
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
|
||||
)
|
||||
|
||||
|
||||
def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
|||
# Parser Builders
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_base_parser(description: str) -> argparse.ArgumentParser:
|
||||
"""Create a base parser with common formatting."""
|
||||
return argparse.ArgumentParser(
|
||||
|
|
@ -261,7 +270,7 @@ def create_full_parser() -> argparse.ArgumentParser:
|
|||
Create a parser with ALL arguments (for grpo.py multi-mode entry point).
|
||||
"""
|
||||
parser = create_base_parser("GRPO Trainer - Multi-mode training")
|
||||
|
||||
|
||||
add_model_args(parser)
|
||||
add_training_args(parser)
|
||||
add_grpo_args(parser)
|
||||
|
|
@ -272,7 +281,7 @@ def create_full_parser() -> argparse.ArgumentParser:
|
|||
add_lora_args(parser)
|
||||
add_distributed_args(parser)
|
||||
add_debug_args(parser)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
|
@ -283,7 +292,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
|
|||
parser = create_base_parser(
|
||||
"Unified GRPO Trainer - Starts vLLM server and trainer in one command"
|
||||
)
|
||||
|
||||
|
||||
add_model_args(parser)
|
||||
add_training_args(parser)
|
||||
add_grpo_args(parser)
|
||||
|
|
@ -291,7 +300,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
|
|||
add_atropos_args(parser)
|
||||
add_wandb_args(parser)
|
||||
add_debug_args(parser)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
|
@ -299,10 +308,11 @@ def create_unified_parser() -> argparse.ArgumentParser:
|
|||
# Legacy API (backwards compatibility)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command-line arguments for the GRPO trainer (grpo.py).
|
||||
|
||||
|
||||
Returns:
|
||||
Parsed arguments namespace
|
||||
"""
|
||||
|
|
@ -313,10 +323,10 @@ def parse_args() -> argparse.Namespace:
|
|||
def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
||||
"""
|
||||
Build a TrainingConfig from parsed CLI arguments.
|
||||
|
||||
|
||||
Args:
|
||||
args: Parsed argparse namespace
|
||||
|
||||
|
||||
Returns:
|
||||
TrainingConfig instance
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue