mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
Better progress tracking
This commit is contained in:
parent
90547f30c7
commit
f16dd9a7d4
1 changed file with 4 additions and 2 deletions
|
|
@@ -14,6 +14,7 @@ import re
|
|||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
|
|
@@ -112,7 +113,7 @@ def evaluate(model, tokenizer, dataset, *args, **kwargs):
|
|||
correct_preds = 0
|
||||
total_preds = 0
|
||||
|
||||
for i in range(len(dataset)):
|
||||
for i in tqdm(range(len(dataset))):
|
||||
item = dataset[i]
|
||||
prompt = item["prompt"]
|
||||
metadata = item["metadata"]
|
||||
|
|
@@ -156,7 +157,7 @@ def main(args):
|
|||
logging_steps=1,
|
||||
bf16=is_bfloat16_supported(),
|
||||
fp16=not is_bfloat16_supported(),
|
||||
per_device_train_batch_size=1,
|
||||
per_device_train_batch_size=args.train_batch_size,
|
||||
gradient_accumulation_steps=1,
|
||||
num_generations=args.num_generations,
|
||||
num_train_epochs=args.train_epochs,
|
||||
|
|
@@ -193,6 +194,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--quantize", action="store_true")
|
||||
parser.add_argument("--num-generations", type=int, default=8)
|
||||
parser.add_argument("--train-epochs", type=int, default=1)
|
||||
parser.add_argument("--train-batch-size", type=int, default=8)
|
||||
|
||||
parser.add_argument("--dataset-seed", type=int, default=42)
|
||||
parser.add_argument("--dataset-size", type=int, default=1000)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue