BLEUBERI/training/config.py
2025-06-04 20:36:43 +00:00

70 lines
4.7 KiB
Python

import argparse
from typing import Optional, Union, List
def parse_args():
parser = argparse.ArgumentParser(description="Training configuration arguments")
# DeepSpeed configuration
parser.add_argument("--deepspeed", type=str, default=None, help="DeepSpeed config file path")
parser.add_argument("--local_rank", type=int, default=-1, help="Used by Deepspeed for distributed training")
# Output and run configuration
parser.add_argument("--run_name", type=str, default="default", help="Name of the training run")
parser.add_argument("--wandb_project", type=str, default="default", help="Name of the wandb project")
# Model and cache configuration
parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
parser.add_argument("--cache_dir", type=str, required=True, help="Cache directory for models and data")
parser.add_argument("--ckpt_dir", type=str, required=True, help="Checkpoint directory for model checkpoints")
# Data configuration
parser.add_argument("--data_path", type=str, default=None, help="Path to the data")
parser.add_argument("--load_from_disk", action="store_true", help="Whether to load data from disk")
parser.add_argument("--shuffle", action="store_true", help="Whether to shuffle the data")
parser.add_argument("--streaming", action="store_true", help="Whether to stream the data")
parser.add_argument("--train_split", type=str, default="train", help="Training split name")
parser.add_argument("--model_max_length", type=int, default=128000, help="Maximum model sequence length")
# Training configuration
parser.add_argument("--save_strategy", type=str, required=True, help="Save strategy (steps or epoch)")
parser.add_argument("--save_steps", type=int, help="Save frequency in steps")
parser.add_argument("--save_total_limit", type=int, default=10, help="Maximum number of checkpoints to keep")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--max_steps", type=int, default=-1, help="Maximum number of training steps")
parser.add_argument("--eval_strategy", type=str, default="steps", help="Evaluation strategy (steps or epoch)")
parser.add_argument("--eval_steps", type=float, default=0.1, help="Evaluation frequency (as ratio or absolute steps)")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--bf16", action="store_true", help="Whether to use bf16 precision")
# Learning rate and optimization
parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate")
parser.add_argument("--lr_scheduler_type", type=str, default="constant_with_warmup", help="Learning rate scheduler type")
parser.add_argument("--warmup_ratio", type=float, default=0.05, help="Warmup ratio")
parser.add_argument("--max_grad_norm", type=float, default=0.2, help="Maximum gradient norm")
# Batch configuration
parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Per device training batch size")
parser.add_argument("--per_device_eval_batch_size", type=int, default=4, help="Per device evaluation batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of gradient accumulation steps")
# Generation configuration
parser.add_argument("--reasoning", action="store_true", help="Whether to add reasoning to the data")
parser.add_argument("--num_generations", type=int, default=4, help="Number of generations")
parser.add_argument("--max_prompt_length", type=int, default=512, help="Maximum prompt length")
parser.add_argument("--max_completion_length", type=int, default=512, help="Maximum completion length")
parser.add_argument("--temperature", type=float, default=0.9, help="Temperature for generation")
parser.add_argument("--num_iterations", type=int, default=1, help="Number of iterations")
# Logging
parser.add_argument("--log_level", type=str, default="info", help="Logging level")
parser.add_argument("--logging_steps", type=int, default=1, help="Logging frequency in steps")
# SFT configuration
parser.add_argument("--packing", action="store_true", help="Whether to pack the data")
parser.add_argument("--dataset_text_field", type=str, default="text", help="Dataset text field")
# Reward functions
parser.add_argument("--reward_funcs", nargs="+", choices=["bleu", "emb", "bertscore", "rm", "format", "bleu_emb", "rouge", "bleu_rouge_f1", "bleu_rm"],
help="List of reward functions to use (e.g., --reward_funcs bleu emb)")
return parser.parse_args()