mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
major refactor
This commit is contained in:
parent
119721ef3d
commit
6833d4d820
13 changed files with 3268 additions and 3423 deletions
271
example_trainer/cli.py
Normal file
271
example_trainer/cli.py
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Command-line interface for GRPO trainer.
|
||||
|
||||
Provides argument parsing and configuration building.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command-line arguments for the GRPO trainer.
|
||||
|
||||
Returns:
|
||||
Parsed arguments namespace
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GRPO Trainer with optional shared-weight vLLM integration",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# === Core Training Arguments ===
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
help="Learning rate for the optimizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training-steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Batch size for training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum sequence length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient-accumulation-steps",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of gradient accumulation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to train on (cuda/cpu)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="trained_model_checkpoints",
|
||||
help="Directory to save model checkpoints",
|
||||
)
|
||||
|
||||
# === vLLM Arguments ===
|
||||
parser.add_argument(
|
||||
"--vllm-restart-interval",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Restart vLLM every N training steps (legacy mode only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-port",
|
||||
type=int,
|
||||
default=9001,
|
||||
help="Port for the vLLM server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--atropos-url",
|
||||
type=str,
|
||||
default="http://localhost:8000",
|
||||
help="URL of the Atropos API/environment server (e.g., gsm8k_server)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.45,
|
||||
help="GPU memory utilization for vLLM server (0.0-1.0)",
|
||||
)
|
||||
|
||||
# === Wandb Arguments ===
|
||||
parser.add_argument(
|
||||
"--use-wandb",
|
||||
action="store_true",
|
||||
help="Enable Weights & Biases logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb-project",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Wandb project name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb-group",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Wandb group name",
|
||||
)
|
||||
|
||||
# === Training Mode Arguments ===
|
||||
parser.add_argument(
|
||||
"--weight-bridge-mode",
|
||||
type=str,
|
||||
choices=["shared_vllm", "lora_only", "none"],
|
||||
default="none",
|
||||
help=(
|
||||
"Weight sync mode: "
|
||||
"'shared_vllm' = attach to vLLM shared memory, "
|
||||
"'lora_only' = train LoRA adapters only, "
|
||||
"'none' = legacy restart-based sync"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainer-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of this trainer in the distributed group",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total processes in the distributed group",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-method",
|
||||
type=str,
|
||||
default="env://",
|
||||
help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-inference-nodes",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of inference nodes to coordinate with (0 = single-node local)",
|
||||
)
|
||||
|
||||
# === LoRA Arguments ===
|
||||
parser.add_argument(
|
||||
"--lora-r",
|
||||
type=int,
|
||||
default=16,
|
||||
help="LoRA rank (dimension of low-rank matrices)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-alpha",
|
||||
type=int,
|
||||
default=32,
|
||||
help="LoRA alpha (scaling factor, typically 2x rank)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-dropout",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Dropout probability for LoRA layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-target-modules",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Module names to apply LoRA to (default: q_proj v_proj)",
|
||||
)
|
||||
|
||||
# === Single-Copy Mode Arguments ===
|
||||
parser.add_argument(
|
||||
"--single-copy",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable TRUE single-copy mode (shared_vllm mode only). "
|
||||
"Trainer attaches to vLLM's model tensors via CUDA IPC. "
|
||||
"Only ONE copy of the model exists in GPU memory! "
|
||||
"Requires trainer and vLLM to be on the SAME GPU(s). "
|
||||
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-config-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Explicit path to vllm_bridge_config.json. "
|
||||
"If not provided, auto-detects from LOGDIR, current directory, "
|
||||
"or /tmp/atropos_bridge. "
|
||||
"This file contains CUDA IPC handles created by vLLM."
|
||||
),
|
||||
)
|
||||
|
||||
# === Debug Flags ===
|
||||
parser.add_argument(
|
||||
"--debug-loading",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable verbose debug output during model loading and IPC attachment. "
|
||||
"Useful for diagnosing single-copy mode issues."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable benchmark timing output showing step time, sync time, "
|
||||
"data fetch time, and GPU memory usage per step."
|
||||
),
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
||||
"""
|
||||
Build a TrainingConfig from parsed CLI arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed argparse namespace
|
||||
|
||||
Returns:
|
||||
TrainingConfig instance
|
||||
"""
|
||||
return TrainingConfig(
|
||||
model_name=args.model_name,
|
||||
lr=args.lr,
|
||||
training_steps=args.training_steps,
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
device=args.device,
|
||||
save_path=args.save_path,
|
||||
vllm_restart_interval=args.vllm_restart_interval,
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
|
||||
use_wandb=args.use_wandb,
|
||||
wandb_project=args.wandb_project,
|
||||
wandb_group=args.wandb_group,
|
||||
weight_bridge_mode=args.weight_bridge_mode,
|
||||
trainer_rank=args.trainer_rank,
|
||||
world_size=args.world_size,
|
||||
init_method=args.init_method,
|
||||
num_inference_nodes=args.num_inference_nodes,
|
||||
lora_r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
single_copy=getattr(args, "single_copy", False),
|
||||
vllm_config_path=getattr(args, "vllm_config_path", None),
|
||||
debug_loading=getattr(args, "debug_loading", False),
|
||||
benchmark=getattr(args, "benchmark", False),
|
||||
atropos_url=getattr(args, "atropos_url", "http://localhost:8000"),
|
||||
)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue