mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
cleanup
This commit is contained in:
parent
24b8ab8574
commit
c8884348c7
8 changed files with 360 additions and 820 deletions
|
|
@ -1,229 +1,212 @@
|
|||
"""
|
||||
Command-line interface for GRPO trainer.
|
||||
|
||||
Provides argument parsing and configuration building.
|
||||
Provides modular argument group builders and configuration building.
|
||||
This is the SINGLE SOURCE OF TRUTH for all CLI arguments.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command-line arguments for the GRPO trainer.
|
||||
|
||||
Returns:
|
||||
Parsed arguments namespace
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GRPO Trainer with optional shared-weight vLLM integration",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
# =============================================================================
|
||||
# Argument Group Builders (modular, reusable)
|
||||
# =============================================================================
|
||||
|
||||
# === Core Training Arguments ===
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add model-related arguments."""
|
||||
group = parser.add_argument_group("Model")
|
||||
group.add_argument(
|
||||
"--model", "--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="model_name",
|
||||
help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
|
||||
def add_training_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add core training arguments."""
|
||||
group = parser.add_argument_group("Training")
|
||||
group.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
help="Learning rate for the optimizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--training-steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Batch size for training",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--seq-len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum sequence length",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--gradient-accumulation-steps",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of gradient accumulation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
||||
default="adamw_8bit",
|
||||
help="Optimizer: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
parser.add_argument(
|
||||
"--kl-coef",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help=(
|
||||
"KL divergence penalty coefficient (beta). "
|
||||
"Controls policy deviation from reference. "
|
||||
"Higher = more conservative, prevents reward hacking. "
|
||||
"0 = disabled (not recommended)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip-eps",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help=(
|
||||
"PPO-style clipping epsilon. "
|
||||
"Clips importance ratio to [1-eps, 1+eps]. "
|
||||
"Prevents destabilizing large policy updates."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-reference-logprobs",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Disable use of inference logprobs as reference policy. "
|
||||
"Falls back to REINFORCE-style updates (not recommended)."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to train on (cuda/cpu)",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="trained_model_checkpoints",
|
||||
help="Directory to save model checkpoints",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--checkpoint-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Save checkpoint every N training steps (0 = only save final)",
|
||||
)
|
||||
|
||||
# === vLLM Arguments ===
|
||||
parser.add_argument(
|
||||
"--vllm-restart-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Restart vLLM every N training steps (legacy mode only)",
|
||||
|
||||
def add_grpo_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add GRPO/PPO hyperparameter arguments."""
|
||||
group = parser.add_argument_group("GRPO/PPO Hyperparameters")
|
||||
group.add_argument(
|
||||
"--kl-coef",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--clip-eps",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
|
||||
)
|
||||
group.add_argument(
|
||||
"--no-reference-logprobs",
|
||||
action="store_true",
|
||||
help="Disable use of inference logprobs as reference policy (not recommended).",
|
||||
)
|
||||
|
||||
|
||||
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add vLLM server arguments."""
|
||||
group = parser.add_argument_group("vLLM Server")
|
||||
group.add_argument(
|
||||
"--vllm-port",
|
||||
type=int,
|
||||
default=9001,
|
||||
help="Port for the vLLM server",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.45,
|
||||
dest="gpu_memory_utilization",
|
||||
help="GPU memory utilization for vLLM server (0.0-1.0)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="Maximum context length for vLLM",
|
||||
)
|
||||
group.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["bfloat16", "float16", "auto"],
|
||||
help="Data type for model weights",
|
||||
)
|
||||
group.add_argument(
|
||||
"--vllm-restart-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Restart vLLM every N training steps (legacy mode only)",
|
||||
)
|
||||
|
||||
|
||||
def add_atropos_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add Atropos API arguments."""
|
||||
group = parser.add_argument_group("Atropos API")
|
||||
group.add_argument(
|
||||
"--atropos-url",
|
||||
type=str,
|
||||
default="http://localhost:8000",
|
||||
help="URL of the Atropos API/environment server (e.g., gsm8k_server)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.45,
|
||||
help="GPU memory utilization for vLLM server (0.0-1.0)",
|
||||
help="URL of the Atropos API/environment server",
|
||||
)
|
||||
|
||||
# === Wandb Arguments ===
|
||||
parser.add_argument(
|
||||
|
||||
def add_wandb_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add Weights & Biases arguments."""
|
||||
group = parser.add_argument_group("Weights & Biases")
|
||||
group.add_argument(
|
||||
"--use-wandb",
|
||||
action="store_true",
|
||||
help="Enable Weights & Biases logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--wandb-project",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Wandb project name",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--wandb-group",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Wandb group name",
|
||||
)
|
||||
|
||||
# === Training Mode Arguments ===
|
||||
parser.add_argument(
|
||||
|
||||
def add_mode_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add training mode arguments."""
|
||||
group = parser.add_argument_group("Training Mode")
|
||||
group.add_argument(
|
||||
"--weight-bridge-mode",
|
||||
type=str,
|
||||
choices=["shared_vllm", "lora_only", "none"],
|
||||
default="none",
|
||||
help=(
|
||||
"Weight sync mode: "
|
||||
"'shared_vllm' = attach to vLLM shared memory, "
|
||||
"'lora_only' = train LoRA adapters only, "
|
||||
"'none' = legacy restart-based sync"
|
||||
),
|
||||
help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainer-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of this trainer in the distributed group",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total processes in the distributed group",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-method",
|
||||
group.add_argument(
|
||||
"--vllm-config-path",
|
||||
type=str,
|
||||
default="env://",
|
||||
help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-inference-nodes",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of inference nodes to coordinate with (0 = single-node local)",
|
||||
default=None,
|
||||
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
|
||||
)
|
||||
|
||||
# === LoRA Arguments ===
|
||||
parser.add_argument(
|
||||
"--lora-r",
|
||||
type=int,
|
||||
default=16,
|
||||
help="LoRA rank (dimension of low-rank matrices)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-alpha",
|
||||
type=int,
|
||||
default=32,
|
||||
help="LoRA alpha (scaling factor, typically 2x rank)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-dropout",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Dropout probability for LoRA layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add LoRA-specific arguments."""
|
||||
group = parser.add_argument_group("LoRA Configuration")
|
||||
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
||||
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
|
||||
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
|
||||
group.add_argument(
|
||||
"--lora-target-modules",
|
||||
type=str,
|
||||
nargs="+",
|
||||
|
|
@ -231,48 +214,100 @@ def parse_args() -> argparse.Namespace:
|
|||
help="Module names to apply LoRA to (default: q_proj v_proj)",
|
||||
)
|
||||
|
||||
# === Single-Copy Mode Arguments ===
|
||||
parser.add_argument(
|
||||
"--single-copy",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable TRUE single-copy mode (shared_vllm mode only). "
|
||||
"Trainer attaches to vLLM's model tensors via CUDA IPC. "
|
||||
"Only ONE copy of the model exists in GPU memory! "
|
||||
"Requires trainer and vLLM to be on the SAME GPU(s). "
|
||||
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-config-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Explicit path to vllm_bridge_config.json. "
|
||||
"If not provided, auto-detects from LOGDIR, current directory, "
|
||||
"or /tmp/atropos_bridge. "
|
||||
"This file contains CUDA IPC handles created by vLLM."
|
||||
),
|
||||
)
|
||||
|
||||
# === Debug Flags ===
|
||||
parser.add_argument(
|
||||
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add distributed training arguments."""
|
||||
group = parser.add_argument_group("Distributed Training")
|
||||
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
|
||||
group.add_argument("--world-size", type=int, default=1, help="World size")
|
||||
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
|
||||
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
|
||||
|
||||
|
||||
def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add debug/benchmark arguments."""
|
||||
group = parser.add_argument_group("Debug & Benchmarking")
|
||||
group.add_argument(
|
||||
"--debug-loading",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable verbose debug output during model loading and IPC attachment. "
|
||||
"Useful for diagnosing single-copy mode issues."
|
||||
),
|
||||
help="Enable verbose debug output during model loading",
|
||||
)
|
||||
parser.add_argument(
|
||||
group.add_argument(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable benchmark timing output showing step time, sync time, "
|
||||
"data fetch time, and GPU memory usage per step."
|
||||
),
|
||||
help="Enable benchmark timing output",
|
||||
)
|
||||
group.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default="./logs",
|
||||
help="Directory for log files",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Parser Builders
|
||||
# =============================================================================
|
||||
|
||||
def create_base_parser(description: str) -> argparse.ArgumentParser:
|
||||
"""Create a base parser with common formatting."""
|
||||
return argparse.ArgumentParser(
|
||||
description=description,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
|
||||
def create_full_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Create a parser with ALL arguments (for grpo.py multi-mode entry point).
|
||||
"""
|
||||
parser = create_base_parser("GRPO Trainer - Multi-mode training")
|
||||
|
||||
add_model_args(parser)
|
||||
add_training_args(parser)
|
||||
add_grpo_args(parser)
|
||||
add_vllm_args(parser)
|
||||
add_atropos_args(parser)
|
||||
add_wandb_args(parser)
|
||||
add_mode_args(parser)
|
||||
add_lora_args(parser)
|
||||
add_distributed_args(parser)
|
||||
add_debug_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def create_unified_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Create a parser for run.py (unified shared_vllm mode with integrated vLLM).
|
||||
"""
|
||||
parser = create_base_parser(
|
||||
"Unified GRPO Trainer - Starts vLLM server and trainer in one command"
|
||||
)
|
||||
|
||||
add_model_args(parser)
|
||||
add_training_args(parser)
|
||||
add_grpo_args(parser)
|
||||
add_vllm_args(parser)
|
||||
add_atropos_args(parser)
|
||||
add_wandb_args(parser)
|
||||
add_debug_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Legacy API (backwards compatibility)
|
||||
# =============================================================================
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command-line arguments for the GRPO trainer (grpo.py).
|
||||
|
||||
Returns:
|
||||
Parsed arguments namespace
|
||||
"""
|
||||
parser = create_full_parser()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
|
@ -302,25 +337,25 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
|
||||
# vLLM settings
|
||||
vllm_restart_interval=args.vllm_restart_interval,
|
||||
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
|
||||
vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45),
|
||||
max_model_len=getattr(args, "max_model_len", 4096),
|
||||
dtype=getattr(args, "dtype", "bfloat16"),
|
||||
use_wandb=args.use_wandb,
|
||||
wandb_project=args.wandb_project,
|
||||
wandb_group=args.wandb_group,
|
||||
weight_bridge_mode=args.weight_bridge_mode,
|
||||
trainer_rank=args.trainer_rank,
|
||||
world_size=args.world_size,
|
||||
init_method=args.init_method,
|
||||
num_inference_nodes=args.num_inference_nodes,
|
||||
lora_r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
single_copy=getattr(args, "single_copy", False),
|
||||
wandb_group=getattr(args, "wandb_group", None),
|
||||
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
|
||||
trainer_rank=getattr(args, "trainer_rank", 0),
|
||||
world_size=getattr(args, "world_size", 1),
|
||||
init_method=getattr(args, "init_method", "env://"),
|
||||
num_inference_nodes=getattr(args, "num_inference_nodes", 0),
|
||||
lora_r=getattr(args, "lora_r", 16),
|
||||
lora_alpha=getattr(args, "lora_alpha", 32),
|
||||
lora_dropout=getattr(args, "lora_dropout", 0.05),
|
||||
lora_target_modules=getattr(args, "lora_target_modules", None),
|
||||
vllm_config_path=getattr(args, "vllm_config_path", None),
|
||||
debug_loading=getattr(args, "debug_loading", False),
|
||||
benchmark=getattr(args, "benchmark", False),
|
||||
atropos_url=getattr(args, "atropos_url", "http://localhost:8000"),
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue