diff --git a/examples/unsloth/train_grpo_lora.py b/examples/unsloth/train_grpo_lora.py index 02edc2e6..b166cfbc 100644 --- a/examples/unsloth/train_grpo_lora.py +++ b/examples/unsloth/train_grpo_lora.py @@ -9,8 +9,10 @@ from unsloth import FastLanguageModel, PatchFastRL PatchFastRL("GRPO", FastLanguageModel) import argparse +import logging import re +import torch from torch.utils.data import Dataset from trl import GRPOConfig, GRPOTrainer from unsloth import is_bfloat16_supported @@ -98,8 +100,40 @@ def train(model, tokenizer, dataset, training_args): args=training_args, train_dataset=dataset, ) + trainer.train() + logging.info("Saving model...") + trainer.save_model("outputs") + + +def evaluate(model, tokenizer, dataset, *args, **kwargs): + model.eval() + correct_preds = 0 + total_preds = 0 + + for i in range(len(dataset)): + item = dataset[i] + prompt = item["prompt"] + metadata = item["metadata"] + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + + with torch.no_grad(): + outputs = model.generate( + inputs, + pad_token_id=tokenizer.eos_token_id, + *args, + **kwargs, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + answer = utils.extract_answer(generated_text) + score = dataset.data.score_answer(answer, entry=metadata) + correct_preds += score + total_preds += 1 + + return correct_preds / total_preds + def main(args): model, tokenizer = get_model_and_tokenizer( @@ -132,6 +166,17 @@ def main(args): train(model, tokenizer, dataset, training_args) + eval_dataset = ReasoningGymDataset( + args.dataset_name, + args.eval_seed, + args.eval_size, + tokenizer, + utils.SYSTEM_PROMPTS["DeepSeekZero"], + ) + + accuracy = evaluate(model, tokenizer, eval_dataset, max_new_tokens=training_args.max_completion_length) + logging.info(f"Evaluation accuracy: {accuracy * 100}%") + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -148,6 +193,9 @@ if __name__ == "__main__": parser.add_argument("--dataset-seed", type=int, default=42) parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--eval-seed", type=int, default=42) + parser.add_argument("--eval-size", type=int, default=100) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.7) args = parser.parse_args()