diff --git a/examples/unsloth/train_grpo_lora.py b/examples/unsloth/train_grpo_lora.py
index 5e4fd21a..3773feb3 100644
--- a/examples/unsloth/train_grpo_lora.py
+++ b/examples/unsloth/train_grpo_lora.py
@@ -14,6 +14,7 @@
 import re
 
 import torch
 from torch.utils.data import Dataset
+from tqdm import tqdm
 from trl import GRPOConfig, GRPOTrainer
 from unsloth import is_bfloat16_supported
@@ -112,7 +113,7 @@ def evaluate(model, tokenizer, dataset, *args, **kwargs):
     correct_preds = 0
     total_preds = 0
 
-    for i in range(len(dataset)):
+    for i in tqdm(range(len(dataset))):
         item = dataset[i]
         prompt = item["prompt"]
         metadata = item["metadata"]
@@ -156,7 +157,7 @@ def main(args):
         logging_steps=1,
         bf16=is_bfloat16_supported(),
         fp16=not is_bfloat16_supported(),
-        per_device_train_batch_size=1,
+        per_device_train_batch_size=args.train_batch_size,
         gradient_accumulation_steps=1,
         num_generations=args.num_generations,
         num_train_epochs=args.num_train_epochs,
@@ -193,6 +194,7 @@ if __name__ == "__main__":
     parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--num-generations", type=int, default=8)
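
With the batch size now exposed as a flag, a typical invocation might look like the following. This is a sketch; any model or dataset flags beyond those visible in this diff are not shown, since the rest of the script is outside the hunks above.

    python examples/unsloth/train_grpo_lora.py --num-generations 8 --train-batch-size 8 --train-epochs 1

One design note on the default of 8: recent versions of TRL's GRPOTrainer require the effective train batch size (num_processes * per_device_train_batch_size * gradient_accumulation_steps) to be evenly divisible by num_generations, so defaulting --train-batch-size to match the num_generations default of 8 keeps a single-GPU run with gradient_accumulation_steps=1 within that constraint.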