This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 22:59:32 -05:00
parent 24b8ab8574
commit c8884348c7
8 changed files with 360 additions and 820 deletions

View file

@ -1,229 +1,212 @@
"""
Command-line interface for GRPO trainer.
Provides argument parsing and configuration building.
Provides modular argument group builders and configuration building.
This is the SINGLE SOURCE OF TRUTH for all CLI arguments.
"""
import argparse
from typing import Optional
import torch
from .config import TrainingConfig
def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments for the GRPO trainer.
Returns:
Parsed arguments namespace
"""
parser = argparse.ArgumentParser(
description="GRPO Trainer with optional shared-weight vLLM integration",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# =============================================================================
# Argument Group Builders (modular, reusable)
# =============================================================================
# === Core Training Arguments ===
parser.add_argument(
"--model-name",
def add_model_args(parser: argparse.ArgumentParser) -> None:
"""Add model-related arguments."""
group = parser.add_argument_group("Model")
group.add_argument(
"--model", "--model-name",
type=str,
required=True,
dest="model_name",
help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')",
)
parser.add_argument(
def add_training_args(parser: argparse.ArgumentParser) -> None:
"""Add core training arguments."""
group = parser.add_argument_group("Training")
group.add_argument(
"--lr",
type=float,
default=1e-5,
help="Learning rate for the optimizer",
)
parser.add_argument(
group.add_argument(
"--training-steps",
type=int,
default=10,
help="Number of training steps to run",
)
parser.add_argument(
group.add_argument(
"--batch-size",
type=int,
default=2,
help="Batch size for training",
)
parser.add_argument(
group.add_argument(
"--seq-len",
type=int,
default=2048,
help="Maximum sequence length",
)
parser.add_argument(
group.add_argument(
"--gradient-accumulation-steps",
type=int,
default=32,
help="Number of gradient accumulation steps",
)
parser.add_argument(
group.add_argument(
"--optimizer",
type=str,
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
default="adamw_8bit",
help="Optimizer: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
)
# === GRPO/PPO Hyperparameters ===
parser.add_argument(
"--kl-coef",
type=float,
default=0.1,
help=(
"KL divergence penalty coefficient (beta). "
"Controls policy deviation from reference. "
"Higher = more conservative, prevents reward hacking. "
"0 = disabled (not recommended)."
),
)
parser.add_argument(
"--clip-eps",
type=float,
default=0.2,
help=(
"PPO-style clipping epsilon. "
"Clips importance ratio to [1-eps, 1+eps]. "
"Prevents destabilizing large policy updates."
),
)
parser.add_argument(
"--no-reference-logprobs",
action="store_true",
help=(
"Disable use of inference logprobs as reference policy. "
"Falls back to REINFORCE-style updates (not recommended)."
),
)
parser.add_argument(
group.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to train on (cuda/cpu)",
)
parser.add_argument(
group.add_argument(
"--save-path",
type=str,
default="trained_model_checkpoints",
help="Directory to save model checkpoints",
)
parser.add_argument(
group.add_argument(
"--checkpoint-interval",
type=int,
default=3,
help="Save checkpoint every N training steps (0 = only save final)",
)
# === vLLM Arguments ===
parser.add_argument(
"--vllm-restart-interval",
type=int,
default=3,
help="Restart vLLM every N training steps (legacy mode only)",
def add_grpo_args(parser: argparse.ArgumentParser) -> None:
"""Add GRPO/PPO hyperparameter arguments."""
group = parser.add_argument_group("GRPO/PPO Hyperparameters")
group.add_argument(
"--kl-coef",
type=float,
default=0.1,
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
)
parser.add_argument(
group.add_argument(
"--clip-eps",
type=float,
default=0.2,
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
)
group.add_argument(
"--no-reference-logprobs",
action="store_true",
help="Disable use of inference logprobs as reference policy (not recommended).",
)
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
"""Add vLLM server arguments."""
group = parser.add_argument_group("vLLM Server")
group.add_argument(
"--vllm-port",
type=int,
default=9001,
help="Port for the vLLM server",
)
parser.add_argument(
group.add_argument(
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
type=float,
default=0.45,
dest="gpu_memory_utilization",
help="GPU memory utilization for vLLM server (0.0-1.0)",
)
group.add_argument(
"--max-model-len",
type=int,
default=4096,
help="Maximum context length for vLLM",
)
group.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16", "auto"],
help="Data type for model weights",
)
group.add_argument(
"--vllm-restart-interval",
type=int,
default=3,
help="Restart vLLM every N training steps (legacy mode only)",
)
def add_atropos_args(parser: argparse.ArgumentParser) -> None:
"""Add Atropos API arguments."""
group = parser.add_argument_group("Atropos API")
group.add_argument(
"--atropos-url",
type=str,
default="http://localhost:8000",
help="URL of the Atropos API/environment server (e.g., gsm8k_server)",
)
parser.add_argument(
"--vllm-gpu-memory-utilization",
type=float,
default=0.45,
help="GPU memory utilization for vLLM server (0.0-1.0)",
help="URL of the Atropos API/environment server",
)
# === Wandb Arguments ===
parser.add_argument(
def add_wandb_args(parser: argparse.ArgumentParser) -> None:
"""Add Weights & Biases arguments."""
group = parser.add_argument_group("Weights & Biases")
group.add_argument(
"--use-wandb",
action="store_true",
help="Enable Weights & Biases logging",
)
parser.add_argument(
group.add_argument(
"--wandb-project",
type=str,
default=None,
help="Wandb project name",
)
parser.add_argument(
group.add_argument(
"--wandb-group",
type=str,
default=None,
help="Wandb group name",
)
# === Training Mode Arguments ===
parser.add_argument(
def add_mode_args(parser: argparse.ArgumentParser) -> None:
"""Add training mode arguments."""
group = parser.add_argument_group("Training Mode")
group.add_argument(
"--weight-bridge-mode",
type=str,
choices=["shared_vllm", "lora_only", "none"],
default="none",
help=(
"Weight sync mode: "
"'shared_vllm' = attach to vLLM shared memory, "
"'lora_only' = train LoRA adapters only, "
"'none' = legacy restart-based sync"
),
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
)
parser.add_argument(
"--trainer-rank",
type=int,
default=0,
help="Rank of this trainer in the distributed group",
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Total processes in the distributed group",
)
parser.add_argument(
"--init-method",
group.add_argument(
"--vllm-config-path",
type=str,
default="env://",
help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')",
)
parser.add_argument(
"--num-inference-nodes",
type=int,
default=0,
help="Number of inference nodes to coordinate with (0 = single-node local)",
default=None,
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
)
# === LoRA Arguments ===
parser.add_argument(
"--lora-r",
type=int,
default=16,
help="LoRA rank (dimension of low-rank matrices)",
)
parser.add_argument(
"--lora-alpha",
type=int,
default=32,
help="LoRA alpha (scaling factor, typically 2x rank)",
)
parser.add_argument(
"--lora-dropout",
type=float,
default=0.05,
help="Dropout probability for LoRA layers",
)
parser.add_argument(
def add_lora_args(parser: argparse.ArgumentParser) -> None:
"""Add LoRA-specific arguments."""
group = parser.add_argument_group("LoRA Configuration")
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
group.add_argument(
"--lora-target-modules",
type=str,
nargs="+",
@ -231,48 +214,100 @@ def parse_args() -> argparse.Namespace:
help="Module names to apply LoRA to (default: q_proj v_proj)",
)
# === Single-Copy Mode Arguments ===
parser.add_argument(
"--single-copy",
action="store_true",
help=(
"Enable TRUE single-copy mode (shared_vllm mode only). "
"Trainer attaches to vLLM's model tensors via CUDA IPC. "
"Only ONE copy of the model exists in GPU memory! "
"Requires trainer and vLLM to be on the SAME GPU(s). "
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
),
)
parser.add_argument(
"--vllm-config-path",
type=str,
default=None,
help=(
"Explicit path to vllm_bridge_config.json. "
"If not provided, auto-detects from LOGDIR, current directory, "
"or /tmp/atropos_bridge. "
"This file contains CUDA IPC handles created by vLLM."
),
)
# === Debug Flags ===
parser.add_argument(
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
"""Add distributed training arguments."""
group = parser.add_argument_group("Distributed Training")
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
group.add_argument("--world-size", type=int, default=1, help="World size")
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
def add_debug_args(parser: argparse.ArgumentParser) -> None:
"""Add debug/benchmark arguments."""
group = parser.add_argument_group("Debug & Benchmarking")
group.add_argument(
"--debug-loading",
action="store_true",
help=(
"Enable verbose debug output during model loading and IPC attachment. "
"Useful for diagnosing single-copy mode issues."
),
help="Enable verbose debug output during model loading",
)
parser.add_argument(
group.add_argument(
"--benchmark",
action="store_true",
help=(
"Enable benchmark timing output showing step time, sync time, "
"data fetch time, and GPU memory usage per step."
),
help="Enable benchmark timing output",
)
group.add_argument(
"--log-dir",
type=str,
default="./logs",
help="Directory for log files",
)
# =============================================================================
# Parser Builders
# =============================================================================
def create_base_parser(description: str) -> argparse.ArgumentParser:
"""Create a base parser with common formatting."""
return argparse.ArgumentParser(
description=description,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
def create_full_parser() -> argparse.ArgumentParser:
"""
Create a parser with ALL arguments (for grpo.py multi-mode entry point).
"""
parser = create_base_parser("GRPO Trainer - Multi-mode training")
add_model_args(parser)
add_training_args(parser)
add_grpo_args(parser)
add_vllm_args(parser)
add_atropos_args(parser)
add_wandb_args(parser)
add_mode_args(parser)
add_lora_args(parser)
add_distributed_args(parser)
add_debug_args(parser)
return parser
def create_unified_parser() -> argparse.ArgumentParser:
"""
Create a parser for run.py (unified shared_vllm mode with integrated vLLM).
"""
parser = create_base_parser(
"Unified GRPO Trainer - Starts vLLM server and trainer in one command"
)
add_model_args(parser)
add_training_args(parser)
add_grpo_args(parser)
add_vllm_args(parser)
add_atropos_args(parser)
add_wandb_args(parser)
add_debug_args(parser)
return parser
# =============================================================================
# Legacy API (backwards compatibility)
# =============================================================================
def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments for the GRPO trainer (grpo.py).
Returns:
Parsed arguments namespace
"""
parser = create_full_parser()
return parser.parse_args()
@ -302,25 +337,25 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
# vLLM settings
vllm_restart_interval=args.vllm_restart_interval,
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
vllm_port=args.vllm_port,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
max_model_len=getattr(args, "max_model_len", 4096),
dtype=getattr(args, "dtype", "bfloat16"),
use_wandb=args.use_wandb,
wandb_project=args.wandb_project,
wandb_group=args.wandb_group,
weight_bridge_mode=args.weight_bridge_mode,
trainer_rank=args.trainer_rank,
world_size=args.world_size,
init_method=args.init_method,
num_inference_nodes=args.num_inference_nodes,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
lora_target_modules=args.lora_target_modules,
single_copy=getattr(args, "single_copy", False),
wandb_group=getattr(args, "wandb_group", None),
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
trainer_rank=getattr(args, "trainer_rank", 0),
world_size=getattr(args, "world_size", 1),
init_method=getattr(args, "init_method", "env://"),
num_inference_nodes=getattr(args, "num_inference_nodes", 0),
lora_r=getattr(args, "lora_r", 16),
lora_alpha=getattr(args, "lora_alpha", 32),
lora_dropout=getattr(args, "lora_dropout", 0.05),
lora_target_modules=getattr(args, "lora_target_modules", None),
vllm_config_path=getattr(args, "vllm_config_path", None),
debug_loading=getattr(args, "debug_loading", False),
benchmark=getattr(args, "benchmark", False),
atropos_url=getattr(args, "atropos_url", "http://localhost:8000"),
)