# Mirror of https://github.com/lilakk/BLEUBERI.git
import torch
import logging
import os
import psutil
import subprocess
import gc
import json
import csv
from pathlib import Path
from os.path import exists
import random
import numpy as np
import pickle as pkl
import hashlib
from transformers import LlamaForCausalLM, Gemma2ForCausalLM, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from evaluate import load
from typing import List, Optional, Union, Tuple, Dict, Any
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so figures can be saved without a display
import matplotlib.pyplot as plt

# Shared metric objects, loaded once at import time via HF `evaluate`
bleu = load("bleu")
rouge = load("rouge")
bertscore = load("bertscore")

def check_existence(path, isDir=False):
    """Return True if `path` exists; when isDir is False, also require it to be a regular file."""
    if isDir and not path.endswith("/"):
        path += "/"
    if not exists(path):
        return False
    if not isDir and not Path(path).is_file():
        return False
    return True

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_model(model_path, cache_dir, access_token=None):
    """Load a causal LM in bf16 with FlashAttention-2, sharded across available GPUs."""
    common_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="balanced",
        cache_dir=cache_dir,
        attn_implementation="flash_attention_2",
    )
    name = model_path.lower()
    if "llama" in name:
        model = LlamaForCausalLM.from_pretrained(model_path, **common_kwargs)
    elif "gemma" in name:
        model = Gemma2ForCausalLM.from_pretrained(model_path, **common_kwargs)
    elif "mistral" in name or "qwen" in name or "olmo" in name:
        model = AutoModelForCausalLM.from_pretrained(model_path, token=access_token, **common_kwargs)
    else:
        raise ValueError(f"Unrecognized model: {model_path}")
    return model

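# Usage sketch for load_model (illustrative; the model ID and cache dir below are
# placeholders, not values taken from this repo):
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   model = load_model("Qwen/Qwen2.5-7B-Instruct", cache_dir="./hf_cache")
#   model.eval()
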
def shorten_ref_model_name(model_name):
    """Map a reference-model name to a short tag used in output filenames."""
    model_name = model_name.lower()
    if "o4-mini" in model_name:
        return "o4mini"
    elif "deepseek" in model_name:
        return "deepseek"
    elif "claude" in model_name:
        return "claude"
    elif "gemini" in model_name:
        return "gemini"
    elif "llama" in model_name:
        return "llama"
    else:
        return model_name

def get_reward_model(device="cuda"):
    """Load the Skywork-Reward-Llama-3.1-8B-v0.2 reward model and its tokenizer."""
    model_name = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
    rm = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="flash_attention_2",
        num_labels=1,
    )
    rm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return rm, rm_tokenizer

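# Usage sketch (assumes a CUDA device with enough memory for an 8B model in bf16):
#
#   rm, rm_tokenizer = get_reward_model(device="cuda")
#   # `rm` has a single-logit head; rm_reward() below uses that logit as the reward.
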
def bleu_reward(prediction, references):
    """Smoothed BLEU of `prediction` against a list of references; 0 for empty or non-string (e.g. NaN) predictions."""
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0
    bleu_score = bleu.compute(predictions=[prediction], references=[references], smooth=True)
    return bleu_score["bleu"]

def rouge_reward(prediction, references):
    """ROUGE-L of `prediction` against a list of references; 0 for empty or non-string predictions."""
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0
    rouge_score = rouge.compute(predictions=[prediction], references=[references])
    return rouge_score["rougeL"]

def bleu_rouge_f1_reward(prediction, references):
    """Harmonic mean (F1) of the BLEU and ROUGE-L rewards."""
    bleu_score = bleu_reward(prediction, references)
    rouge_score = rouge_reward(prediction, references)
    return 2 * bleu_score * rouge_score / (bleu_score + rouge_score) if (bleu_score + rouge_score) > 0 else 0.0

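# Example of the string-overlap rewards (strings are illustrative; `references`
# is a list of acceptable outputs for one prompt):
#
#   refs = ["The cat sat on the mat.", "A cat was sitting on the mat."]
#   bleu_reward("The cat sat on the mat.", refs)         # 1.0 for an exact match
#   rouge_reward("The cat sat on a mat.", refs)          # ROUGE-L F-measure in [0, 1]
#   bleu_rouge_f1_reward("The cat sat on a mat.", refs)  # harmonic mean of the two
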
def bertscore_reward(prediction, references):
    """BERTScore F1 (distilbert-base-uncased) of `prediction` against the references."""
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0
    bertscore_score = bertscore.compute(predictions=[prediction], references=[references], model_type="distilbert-base-uncased")
    return bertscore_score["f1"][0]

def rm_reward(predictions, prompts, rm_model=None, rm_tokenizer=None, device="cuda"):
    """Score prompt/response pairs with the reward model; accepts a single pair or parallel lists."""
    single_input = not isinstance(predictions, list)
    if single_input:
        predictions = [predictions]
        prompts = [prompts]

    if rm_model is None or rm_tokenizer is None:
        rm_model, rm_tokenizer = get_reward_model(device)

    all_scores = []

    for prediction, prompt in tqdm(zip(predictions, prompts), desc="Computing reward scores", total=len(predictions), disable=single_input):
        # Empty or non-string (e.g. NaN) predictions get a score of 0
        if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
            all_scores.append(0.0)
            continue
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": prediction}
        ]
        conv_tokenized = rm_tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            score = rm_model(conv_tokenized).logits[0][0].item()
        all_scores.append(float(score))

    return all_scores[0] if single_input else all_scores

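# Usage sketch (prompt/response pairs are illustrative). Passing rm_model and
# rm_tokenizer explicitly avoids reloading the 8B reward model on every call:
#
#   rm, rm_tok = get_reward_model()
#   scores = rm_reward(
#       ["Paris is the capital of France."],
#       ["What is the capital of France?"],
#       rm_model=rm, rm_tokenizer=rm_tok,
#   )  # -> list with one float reward
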
def get_model_name(name):
    """Derive a short model identifier from a path; checkpoint paths keep their last two components."""
    if "ckpts" in name:
        model_basename = "_".join(name.split("/")[-2:])
    else:
        model_basename = os.path.basename(name)
    return model_basename

def get_ref_models_str(ref_models_or_count):
    """Return (number of references, filename suffix) from a count or a list of model names."""
    if isinstance(ref_models_or_count, int):
        nrefs = ref_models_or_count
        ref_models_str = ""
    else:
        nrefs = len(ref_models_or_count)
        ref_models_str = "-" + "-".join([shorten_ref_model_name(m) for m in ref_models_or_count])
    return nrefs, ref_models_str

def build_score_path(base_dir, data_path, metric, model, nrefs, ref_models_str=""):
    """Build the JSON output path encoding dataset, metric, model, and reference set."""
    data_basename = os.path.basename(data_path)
    return os.path.join(
        base_dir,
        f"{data_basename}_{metric}_{get_model_name(model)}_{nrefs}ref{ref_models_str}.json"
    )

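# Example of the filename scheme (arguments are hypothetical):
#
#   nrefs, ref_str = get_ref_models_str(["claude-3-7-sonnet", "o4-mini"])
#   # nrefs == 2, ref_str == "-claude-o4mini"
#   build_score_path("scores", "data/tulu3_50k", "bleu", "Qwen2.5-7B", nrefs, ref_str)
#   # -> "scores/tulu3_50k_bleu_Qwen2.5-7B_2ref-claude-o4mini.json"
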
def save_histogram(score_values: List[float], metric_name_for_plot: str, title: str, xlabel: str, fig_path: str, bins: int = 50):
    """Helper function to generate and save a histogram."""
    if not score_values:
        print(f"No score values provided for metric '{metric_name_for_plot}', skipping histogram generation for {fig_path}.")
        return

    plt.figure(figsize=(10, 6))
    plt.hist(score_values, bins=bins)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.savefig(fig_path)
    plt.close()
    print(f"Saved score distribution plot to {fig_path}")

def save_scores(scores, score_path):
    """Write scores to JSON and save a histogram of the score distribution alongside it."""
    os.makedirs(os.path.dirname(score_path), exist_ok=True)

    with open(score_path, 'w') as f:
        json.dump(scores, f)

    print(f"Saved scores to {score_path}")

    if scores:  # Ensure scores is not empty
        metric_name = list(scores.values())[0]['metric']
        score_values = [s["score"] for s in scores.values()]
        fig_path = f"{os.path.splitext(score_path)[0]}_distribution.png"
        save_histogram(score_values, metric_name, f"{metric_name.upper()} Score Distribution", f"{metric_name.upper()} Score", fig_path)
    else:
        print("No scores to save, skipping histogram generation.")
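
# Expected `scores` layout, inferred from the accesses in save_scores (the keys
# beyond "metric" and "score" are illustrative):
#
#   scores = {
#       "example_0": {"metric": "bleu", "score": 0.42},
#       "example_1": {"metric": "bleu", "score": 0.17},
#   }
#   save_scores(scores, "scores/demo_bleu.json")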