Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2026-04-19 12:58:07 +00:00
Add eval

parent e26161713e
commit a01110aa7c

1 changed file with 48 additions and 0 deletions
@@ -9,8 +9,10 @@ from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import argparse
import logging
import re

import torch
from torch.utils.data import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
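Note on the import order above: per unsloth's GRPO examples, PatchFastRL("GRPO", FastLanguageModel) is meant to run before TRL's trainer classes are imported, which is why the from trl import GRPOConfig, GRPOTrainer line sits below the patch call.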
@@ -98,8 +100,40 @@ def train(model, tokenizer, dataset, training_args):
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

    logging.info("Saving model...")
    trainer.save_model("outputs")


def evaluate(model, tokenizer, dataset, *args, **kwargs):
    model.eval()
    correct_preds = 0
    total_preds = 0

    for i in range(len(dataset)):
        item = dataset[i]
        prompt = item["prompt"]
        metadata = item["metadata"]
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            outputs = model.generate(
                *args,
                **inputs,
                pad_token_id=tokenizer.eos_token_id,
                **kwargs,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = utils.extract_answer(generated_text)
        score = dataset.data.score_answer(answer, entry=metadata)
        correct_preds += score
        total_preds += 1

    return correct_preds / total_preds


def main(args):
    model, tokenizer = get_model_and_tokenizer(
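The extra *args/**kwargs on evaluate() are forwarded straight into model.generate, so generation settings are chosen at the call site. As a standalone sanity check, a driver might look like the minimal sketch below; the checkpoint name, generation settings, and dataset arguments are illustrative assumptions, not values from this commit.

# Hypothetical standalone driver for evaluate(); mirrors how main() wires things up.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative checkpoint, not the script's default
    max_seq_length=2048,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # enable unsloth's fast inference path

# Built the same way main() builds it below; the arguments are placeholders.
eval_dataset = ReasoningGymDataset(
    "leg_counting", 42, 100, tokenizer, utils.SYSTEM_PROMPTS["DeepSeekZero"]
)

accuracy = evaluate(
    model,
    tokenizer,
    eval_dataset,
    max_new_tokens=512,  # forwarded to model.generate via **kwargs
    do_sample=False,     # greedy decoding for reproducible scoring
)
print(f"accuracy: {accuracy:.2%}")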
@@ -132,6 +166,17 @@ def main(args):

    train(model, tokenizer, dataset, training_args)

    eval_dataset = ReasoningGymDataset(
        args.dataset_name,
        args.eval_seed,
        args.eval_size,
        tokenizer,
        utils.SYSTEM_PROMPTS["DeepSeekZero"],
    )

    accuracy = evaluate(model, tokenizer, eval_dataset, max_new_tokens=training_args.max_completion_length)
    logging.info(f"Evaluation accuracy: {accuracy * 100}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
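ReasoningGymDataset itself is defined elsewhere in this file and is not part of the diff. For orientation only, a minimal sketch of such a wrapper over the reasoning_gym API follows; the constructor signature matches the call in main() above, but the chat-template formatting is an assumption, not the commit's code.

import reasoning_gym
from torch.utils.data import Dataset


class ReasoningGymDataset(Dataset):
    # Hypothetical sketch, not the commit's implementation.
    def __init__(self, name, seed, size, tokenizer, system_prompt):
        # Procedural dataset; exposed as .data so callers can score answers.
        self.data = reasoning_gym.create_dataset(name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        prompt = self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": entry["question"]},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        # evaluate() feeds this back into self.data.score_answer(entry=...).
        return {"prompt": prompt, "metadata": entry}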
@@ -148,6 +193,9 @@ if __name__ == "__main__":
    parser.add_argument("--dataset-seed", type=int, default=42)
    parser.add_argument("--dataset-size", type=int, default=1000)

    parser.add_argument("--eval-seed", type=int, default=42)
    parser.add_argument("--eval-size", type=int, default=100)

    parser.add_argument("--gpu-memory-utilization", type=float, default=0.7)

    args = parser.parse_args()
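Usage note: the new --eval-seed and --eval-size flags let the held-out evaluation set be sampled with its own seed and size, independently of --dataset-seed/--dataset-size used for training. A hypothetical invocation (the script name and the --dataset-name flag implied by args.dataset_name in main are not shown in this hunk) might be: python train_grpo.py --dataset-name leg_counting --eval-seed 7 --eval-size 100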