# BLEUBERI/training/utils.py
import torch
import logging
import os
import psutil
import subprocess
import gc
import json
import csv
from pathlib import Path
from os.path import exists
import random
import numpy as np
import pickle as pkl
import hashlib
from transformers import (
    LlamaForCausalLM,
    Gemma2ForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from tqdm import tqdm
from evaluate import load
from typing import List, Optional, Union, Tuple, Dict, Any
import matplotlib
matplotlib.use('Agg')  # headless backend so figures can be saved without a display
import matplotlib.pyplot as plt

# Load the metrics once at import time so every reward call reuses the same
# objects instead of re-initializing them.
bleu = load("bleu")
rouge = load("rouge")
bertscore = load("bertscore")

def check_existence(path, isDir=False):
    """Return True if `path` exists and, when isDir is False, is a regular file."""
    if isDir and not path.endswith("/"):
        path += "/"
    pathExists = exists(path)
    if not pathExists:
        return False
    if not isDir:
        filePath = Path(path)
        if not filePath.is_file():
            return False
    return True
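
# A minimal usage sketch (these paths are hypothetical, not from this repo):
#   check_existence("outputs/scores.json")    # True only for an existing file
#   check_existence("outputs", isDir=True)    # True for an existing directory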

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility.

    Note: cudnn.deterministic=True / benchmark=False trades some GPU speed
    for deterministic kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
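
# Usage sketch: call once at startup, before building any models or dataloaders:
#   set_seed(42)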

def load_model(model_path, cache_dir, access_token=None):
    """Load a causal LM in bfloat16 with FlashAttention-2, sharded across GPUs."""
    common_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="balanced",
        cache_dir=cache_dir,
        attn_implementation="flash_attention_2",
        token=access_token,  # needed for gated checkpoints (e.g. Llama, Gemma)
    )
    name = model_path.lower()
    if "llama" in name:
        model = LlamaForCausalLM.from_pretrained(model_path, **common_kwargs)
    elif "gemma" in name:
        model = Gemma2ForCausalLM.from_pretrained(model_path, **common_kwargs)
    elif "mistral" in name or "qwen" in name or "olmo" in name:
        model = AutoModelForCausalLM.from_pretrained(model_path, **common_kwargs)
    else:
        raise ValueError("Unrecognized model: {}".format(model_path))
    return model
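
# A minimal usage sketch; the model id, cache dir, and HF_TOKEN env var are
# placeholders (a real access token is needed for gated checkpoints):
#   model = load_model(
#       "meta-llama/Llama-3.1-8B-Instruct",
#       cache_dir="./hf_cache",
#       access_token=os.environ.get("HF_TOKEN"),
#   )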

def shorten_ref_model_name(model_name):
    model_name = model_name.lower()
    if "o4-mini" in model_name:
        return "o4mini"
    elif "deepseek" in model_name:
        return "deepseek"
    elif "claude" in model_name:
        return "claude"
    elif "gemini" in model_name:
        return "gemini"
    elif "llama" in model_name:
        return "llama"
    else:
        return model_name
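
# Example behavior (hypothetical reference-model names):
#   shorten_ref_model_name("claude-3-7-sonnet")  -> "claude"
#   shorten_ref_model_name("o4-mini-high")       -> "o4mini"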

def get_reward_model(device="cuda"):
    """Load the Skywork 8B reward model as a single-logit sequence classifier."""
    model_name = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
    rm = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="flash_attention_2",
        num_labels=1,
    )
    rm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return rm, rm_tokenizer

def bleu_reward(prediction, references):
    # Guard against NaN (float) predictions and empty strings.
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0.0
    bleu_score = bleu.compute(predictions=[prediction], references=[references], smooth=True)
    return bleu_score["bleu"]

def rouge_reward(prediction, references):
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0.0
    rouge_score = rouge.compute(predictions=[prediction], references=[references])
    return rouge_score["rougeL"]

def bleu_rouge_f1_reward(prediction, references):
    # Harmonic mean of BLEU and ROUGE-L; both components already return 0
    # for empty or NaN predictions.
    bleu_score = bleu_reward(prediction, references)
    rouge_score = rouge_reward(prediction, references)
    return 2 * bleu_score * rouge_score / (bleu_score + rouge_score) if (bleu_score + rouge_score) > 0 else 0.0

def bertscore_reward(prediction, references):
    if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
        return 0.0
    bertscore_score = bertscore.compute(predictions=[prediction], references=[references], model_type="distilbert-base-uncased")
    return bertscore_score["f1"][0]
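
# How the lexical rewards compose (toy strings; numbers purely illustrative):
#   p, refs = "the cat sat on the mat", ["the cat sat on a mat"]
#   b = bleu_reward(p, refs)         # smoothed BLEU in [0, 1]
#   r = rouge_reward(p, refs)        # ROUGE-L F-measure in [0, 1]
#   bleu_rouge_f1_reward(p, refs)    # harmonic mean 2*b*r / (b + r)
# e.g. with b = 0.4 and r = 0.6, the harmonic mean is 2*0.24 / 1.0 = 0.48.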

def rm_reward(predictions, prompts, rm_model=None, rm_tokenizer=None, device="cuda"):
    """Score (prompt, response) pairs with the reward model; accepts a single pair or lists."""
    single_input = not isinstance(predictions, list)
    if single_input:
        predictions = [predictions]
        prompts = [prompts]
    if rm_model is None or rm_tokenizer is None:
        rm_model, rm_tokenizer = get_reward_model(device)
    all_scores = []
    for prediction, prompt in tqdm(zip(predictions, prompts), desc="Computing reward scores", total=len(predictions), disable=single_input):
        # Empty or NaN responses get a score of 0 instead of a model pass.
        if isinstance(prediction, float) or (isinstance(prediction, str) and len(prediction.strip()) == 0):
            all_scores.append(0.0)
            continue
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": prediction}
        ]
        conv_tokenized = rm_tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            score = rm_model(conv_tokenized).logits[0][0].item()
        all_scores.append(float(score))
    return all_scores[0] if single_input else all_scores
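
# A minimal usage sketch (strings invented for illustration). Passing a
# preloaded rm_model/rm_tokenizer avoids reloading the 8B reward model per call:
#   rm, rm_tok = get_reward_model("cuda")
#   score = rm_reward(
#       "Paris is the capital of France.",
#       "What is the capital of France?",
#       rm_model=rm, rm_tokenizer=rm_tok,
#   )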

def get_model_name(name):
    if "ckpts" in name:
        # For checkpoint paths, keep the last two components so the run name survives.
        model_basename = name.split("/")[-2:]
        model_basename = "_".join(model_basename)
    else:
        model_basename = os.path.basename(name)
    return model_basename
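
# Example behavior (hypothetical paths):
#   get_model_name("google/gemma-2-9b")          -> "gemma-2-9b"
#   get_model_name("ckpts/run1/checkpoint-500")  -> "run1_checkpoint-500"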

def get_ref_models_str(ref_models_or_count):
    """Return (n_refs, filename suffix) from either a reference count or a list of model names."""
    if isinstance(ref_models_or_count, int):
        nrefs = ref_models_or_count
        ref_models_str = ""
    else:
        nrefs = len(ref_models_or_count)
        ref_models_str = "-" + "-".join([shorten_ref_model_name(m) for m in ref_models_or_count])
    return nrefs, ref_models_str
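
# Example behavior (hypothetical inputs):
#   get_ref_models_str(3)                      -> (3, "")
#   get_ref_models_str(["o4-mini", "claude"])  -> (2, "-o4mini-claude")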

def build_score_path(base_dir, data_path, metric, model, nrefs, ref_models_str=""):
    data_basename = os.path.basename(data_path)
    return os.path.join(
        base_dir,
        f"{data_basename}_{metric}_{get_model_name(model)}_{nrefs}ref{ref_models_str}.json"
    )
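
# Example output (hypothetical arguments):
#   build_score_path("scores", "data/tulu_50k", "bleu", "google/gemma-2-9b", 1)
#   -> "scores/tulu_50k_bleu_gemma-2-9b_1ref.json"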

def save_histogram(score_values: List[float], metric_name_for_plot: str, title: str, xlabel: str, fig_path: str, bins: int = 50):
    """Helper function to generate and save a histogram."""
    if not score_values:
        print(f"No score values provided for metric '{metric_name_for_plot}', skipping histogram generation for {fig_path}.")
        return
    plt.figure(figsize=(10, 6))
    plt.hist(score_values, bins=bins)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.savefig(fig_path)
    plt.close()
    print(f"Saved score distribution plot to {fig_path}")

def save_scores(scores, score_path):
    # `or "."` guards against bare filenames with no directory component.
    os.makedirs(os.path.dirname(score_path) or ".", exist_ok=True)
    with open(score_path, 'w') as f:
        json.dump(scores, f)
    print(f"Saved scores to {score_path}")
    if scores:  # Ensure scores is not empty
        metric_name = next(iter(scores.values()))["metric"]
        score_values = [s["score"] for s in scores.values()]
        fig_path = f"{os.path.splitext(score_path)[0]}_distribution.png"
        save_histogram(score_values, metric_name, f"{metric_name.upper()} Score Distribution", f"{metric_name.upper()} Score", fig_path)
    else:
        print("No scores to save, skipping histogram generation.")