mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-26 17:13:21 +00:00
70 lines
4.7 KiB
Python
70 lines
4.7 KiB
Python
import argparse
|
|
from typing import Optional, Union, List
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the training argument parser and parse command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, which matches the behaviour of calling
            this function with no arguments.

    Returns:
        The populated ``argparse.Namespace``. Options without an explicit
        default (``--save_steps``, ``--reward_funcs``) are ``None`` when not
        supplied; ``store_true`` flags are ``False`` when not supplied.

    Raises:
        SystemExit: Raised by argparse when a required option
            (``--model_path``, ``--cache_dir``, ``--ckpt_dir``,
            ``--save_strategy``) is missing or a value fails validation.
    """
    parser = argparse.ArgumentParser(description="Training configuration arguments")

    # DeepSpeed configuration
    parser.add_argument("--deepspeed", type=str, default=None, help="DeepSpeed config file path")
    parser.add_argument("--local_rank", type=int, default=-1, help="Used by Deepspeed for distributed training")

    # Output and run configuration
    parser.add_argument("--run_name", type=str, default="default", help="Name of the training run")
    parser.add_argument("--wandb_project", type=str, default="default", help="Name of the wandb project")

    # Model and cache configuration
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
    parser.add_argument("--cache_dir", type=str, required=True, help="Cache directory for models and data")
    parser.add_argument("--ckpt_dir", type=str, required=True, help="Checkpoint directory for model checkpoints")

    # Data configuration
    parser.add_argument("--data_path", type=str, default=None, help="Path to the data")
    parser.add_argument("--load_from_disk", action="store_true", help="Whether to load data from disk")
    parser.add_argument("--shuffle", action="store_true", help="Whether to shuffle the data")
    parser.add_argument("--streaming", action="store_true", help="Whether to stream the data")
    parser.add_argument("--train_split", type=str, default="train", help="Training split name")
    parser.add_argument("--model_max_length", type=int, default=128000, help="Maximum model sequence length")

    # Training configuration
    parser.add_argument("--save_strategy", type=str, required=True, help="Save strategy (steps or epoch)")
    parser.add_argument("--save_steps", type=int, help="Save frequency in steps")
    parser.add_argument("--save_total_limit", type=int, default=10, help="Maximum number of checkpoints to keep")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--max_steps", type=int, default=-1, help="Maximum number of training steps")
    parser.add_argument("--eval_strategy", type=str, default="steps", help="Evaluation strategy (steps or epoch)")
    # Parsed as float so the value can express either a ratio or a step count.
    parser.add_argument("--eval_steps", type=float, default=0.1, help="Evaluation frequency (as ratio or absolute steps)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--bf16", action="store_true", help="Whether to use bf16 precision")

    # Learning rate and optimization
    parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant_with_warmup", help="Learning rate scheduler type")
    parser.add_argument("--warmup_ratio", type=float, default=0.05, help="Warmup ratio")
    parser.add_argument("--max_grad_norm", type=float, default=0.2, help="Maximum gradient norm")

    # Batch configuration
    parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Per device training batch size")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4, help="Per device evaluation batch size")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of gradient accumulation steps")

    # Generation configuration
    parser.add_argument("--reasoning", action="store_true", help="Whether to add reasoning to the data")
    parser.add_argument("--num_generations", type=int, default=4, help="Number of generations")
    parser.add_argument("--max_prompt_length", type=int, default=512, help="Maximum prompt length")
    parser.add_argument("--max_completion_length", type=int, default=512, help="Maximum completion length")
    parser.add_argument("--temperature", type=float, default=0.9, help="Temperature for generation")
    parser.add_argument("--num_iterations", type=int, default=1, help="Number of iterations")

    # Logging
    parser.add_argument("--log_level", type=str, default="info", help="Logging level")
    parser.add_argument("--logging_steps", type=int, default=1, help="Logging frequency in steps")

    # SFT configuration
    parser.add_argument("--packing", action="store_true", help="Whether to pack the data")
    parser.add_argument("--dataset_text_field", type=str, default="text", help="Dataset text field")

    # Reward functions
    parser.add_argument(
        "--reward_funcs",
        nargs="+",
        choices=[
            "bleu",
            "emb",
            "bertscore",
            "rm",
            "format",
            "bleu_emb",
            "rouge",
            "bleu_rouge_f1",
            "bleu_rm",
        ],
        help="List of reward functions to use (e.g., --reward_funcs bleu emb)",
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # original no-argument call behaviour.
    return parser.parse_args(argv)
|