Add tqdm progress bar to evaluation loop; make train batch size configurable via --train-batch-size

This commit is contained in:
Oliver 2025-02-20 23:32:54 +00:00
parent 90547f30c7
commit f16dd9a7d4

View file

@ -14,6 +14,7 @@ import re
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
@ -112,7 +113,7 @@ def evaluate(model, tokenizer, dataset, *args, **kwargs):
correct_preds = 0
total_preds = 0
for i in range(len(dataset)):
for i in tqdm(range(len(dataset))):
item = dataset[i]
prompt = item["prompt"]
metadata = item["metadata"]
@ -156,7 +157,7 @@ def main(args):
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=1,
per_device_train_batch_size=args.train_batch_size,
gradient_accumulation_steps=1,
num_generations=args.num_generations,
num_train_epochs=args.train_epochs,
@ -193,6 +194,7 @@ if __name__ == "__main__":
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--num-generations", type=int, default=8)
parser.add_argument("--train-epochs", type=int, default=1)
parser.add_argument("--train-batch-size", type=int, default=8)
parser.add_argument("--dataset-seed", type=int, default=42)
parser.add_argument("--dataset-size", type=int, default=1000)