This commit is contained in:
Oliver 2025-02-20 22:18:52 +00:00
parent e26161713e
commit a01110aa7c

View file

@ -9,8 +9,10 @@ from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
import argparse
import logging
import re
import torch
from torch.utils.data import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
@ -98,8 +100,40 @@ def train(model, tokenizer, dataset, training_args):
args=training_args,
train_dataset=dataset,
)
trainer.train()
logging.info("Saving model...")
trainer.save_model("outputs")
def evaluate(model, tokenizer, dataset, *args, **kwargs):
model.eval()
correct_preds = 0
total_preds = 0
for i in range(len(dataset)):
item = dataset[i]
prompt = item["prompt"]
metadata = item["metadata"]
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model.generate(
inputs,
pad_token_id=tokenizer.eos_token_id,
*args,
**kwargs,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = utils.extract_answer(generated_text)
score = dataset.data.score_answer(answer, entry=metadata)
correct_preds += score
total_preds += 1
return correct_preds / total_preds
def main(args):
model, tokenizer = get_model_and_tokenizer(
@ -132,6 +166,17 @@ def main(args):
train(model, tokenizer, dataset, training_args)
eval_dataset = ReasoningGymDataset(
args.dataset_name,
args.eval_seed,
args.eval_size,
tokenizer,
utils.SYSTEM_PROMPTS["DeepSeekZero"],
)
accuracy = evaluate(model, tokenizer, eval_dataset, max_new_tokens=training_args.max_completion_length)
logging.info(f"Evaluation accuracy: {accuracy * 100}%")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -148,6 +193,9 @@ if __name__ == "__main__":
parser.add_argument("--dataset-seed", type=int, default=42)
parser.add_argument("--dataset-size", type=int, default=1000)
parser.add_argument("--eval-seed", type=int, default=42)
parser.add_argument("--eval-size", type=int, default=100)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.7)
args = parser.parse_args()