mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
463 lines
17 KiB
Python
463 lines
17 KiB
Python
"""
|
|
Usage:
|
|
python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all]
|
|
"""
|
|
import argparse
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import json
|
|
import os
|
|
import time
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from fastchat.llm_judge.common import (
|
|
load_questions,
|
|
load_model_answers,
|
|
load_judge_prompts,
|
|
check_data,
|
|
play_a_match_pair,
|
|
play_a_match_single,
|
|
get_model_list,
|
|
Judge,
|
|
MatchPair,
|
|
MatchSingle,
|
|
NEED_REF_CATS,
|
|
)
|
|
|
|
|
|
def load_existing_judgments(output_file_path, mode):
|
|
"""Loads existing judgments from the output file."""
|
|
existing_judgments = set()
|
|
if not os.path.exists(output_file_path):
|
|
print(f"No existing judgment file found at: {output_file_path}")
|
|
return existing_judgments
|
|
|
|
print(f"Loading existing judgments from: {output_file_path}")
|
|
count = 0
|
|
malformed_count = 0
|
|
with open(output_file_path, "r", encoding="utf-8") as f:
|
|
for line in f:
|
|
try:
|
|
data = json.loads(line)
|
|
question_id = data.get("question_id")
|
|
|
|
if question_id is None: # Skip if no question_id
|
|
malformed_count +=1
|
|
continue
|
|
|
|
if mode == "single":
|
|
# The judgment output for single mode has "model_id"
|
|
# Changed to "model" based on user's file format
|
|
model_key_to_check = "model"
|
|
model_id = data.get(model_key_to_check) # Use the new variable
|
|
if model_id is not None:
|
|
existing_judgments.add((question_id, model_id))
|
|
count += 1
|
|
else:
|
|
malformed_count += 1
|
|
# print(f"Debug: Malformed single - QID: {question_id}, Keys: {data.keys()}, Expected key: {model_key_to_check}")
|
|
else: # pairwise-baseline or pairwise-all
|
|
# The judgment output for pair mode has "model_1" and "model_2"
|
|
model_1 = data.get("model_1")
|
|
model_2 = data.get("model_2")
|
|
if model_1 is not None and model_2 is not None:
|
|
# The order in MatchPair is (model1, model2)
|
|
# and this order is preserved from make_match / make_match_all_pairs
|
|
existing_judgments.add((question_id, model_1, model_2))
|
|
count += 1
|
|
else:
|
|
malformed_count +=1
|
|
except json.JSONDecodeError:
|
|
malformed_count += 1
|
|
# print(f"Warning: Skipping malformed JSON line in {output_file_path}: {line.strip()}") # Can be noisy
|
|
|
|
if malformed_count > 0:
|
|
print(f"Warning: Skipped {malformed_count} malformed or incomplete lines in existing judgments.")
|
|
print(f"Successfully loaded {count} unique existing judgments.")
|
|
return existing_judgments
|
|
|
|
|
|
def make_match(
|
|
questions,
|
|
models,
|
|
model_answers,
|
|
judge,
|
|
baseline_model,
|
|
ref_answers=None,
|
|
multi_turn=False,
|
|
):
|
|
matches = []
|
|
for q in questions:
|
|
if multi_turn and len(q["turns"]) != 2:
|
|
continue
|
|
for i in range(len(models)):
|
|
q_id = q["question_id"]
|
|
m_1 = models[i]
|
|
m_2 = baseline_model
|
|
if m_1 == m_2:
|
|
continue
|
|
# Ensure answers exist
|
|
if q_id not in model_answers.get(m_1, {}) or \
|
|
q_id not in model_answers.get(baseline_model, {}):
|
|
# print(f"Warning: Skipping Q:{q_id} for {m_1} vs {m_2} due to missing answer(s).")
|
|
continue
|
|
|
|
a_1 = model_answers[m_1][q_id]
|
|
a_2 = model_answers[baseline_model][q_id]
|
|
|
|
current_ref = None
|
|
if ref_answers is not None and judge.ref_based:
|
|
if judge.model_name in ref_answers and q_id in ref_answers[judge.model_name]:
|
|
current_ref = ref_answers[judge.model_name][q_id]
|
|
# else:
|
|
# print(f"Warning: No ref answer for Q:{q_id} with judge {judge.model_name} though ref_based.")
|
|
|
|
|
|
match = MatchPair(
|
|
dict(q),
|
|
m_1,
|
|
m_2,
|
|
a_1,
|
|
a_2,
|
|
judge,
|
|
ref_answer=current_ref,
|
|
multi_turn=multi_turn,
|
|
)
|
|
matches.append(match)
|
|
return matches
|
|
|
|
|
|
def make_match_all_pairs(
|
|
questions,
|
|
models,
|
|
model_answers,
|
|
judge,
|
|
baseline_model=None, # Not used but kept for signature consistency if called polymorphically
|
|
ref_answers=None,
|
|
multi_turn=False,
|
|
):
|
|
matches = []
|
|
for q in questions:
|
|
if multi_turn and len(q["turns"]) != 2:
|
|
continue
|
|
for i in range(len(models)):
|
|
for j in range(i + 1, len(models)):
|
|
q_id = q["question_id"]
|
|
m_1 = models[i]
|
|
m_2 = models[j]
|
|
|
|
if q_id not in model_answers.get(m_1, {}) or \
|
|
q_id not in model_answers.get(m_2, {}):
|
|
# print(f"Warning: Skipping Q:{q_id} for {m_1} vs {m_2} due to missing answer(s).")
|
|
continue
|
|
|
|
a_1 = model_answers[m_1][q_id]
|
|
a_2 = model_answers[m_2][q_id]
|
|
|
|
current_ref = None
|
|
if ref_answers is not None and judge.ref_based:
|
|
if judge.model_name in ref_answers and q_id in ref_answers[judge.model_name]:
|
|
current_ref = ref_answers[judge.model_name][q_id]
|
|
# else:
|
|
# print(f"Warning: No ref answer for Q:{q_id} with judge {judge.model_name} though ref_based.")
|
|
|
|
match = MatchPair(
|
|
dict(q),
|
|
m_1,
|
|
m_2,
|
|
a_1,
|
|
a_2,
|
|
judge,
|
|
ref_answer=current_ref,
|
|
multi_turn=multi_turn,
|
|
)
|
|
matches.append(match)
|
|
return matches
|
|
|
|
|
|
def make_match_single(
|
|
questions,
|
|
models,
|
|
model_answers,
|
|
judge,
|
|
baseline_model=None, # Not used but kept for signature consistency
|
|
ref_answers=None,
|
|
multi_turn=False,
|
|
):
|
|
matches = []
|
|
for q in questions:
|
|
if multi_turn and len(q["turns"]) != 2:
|
|
continue
|
|
for i in range(len(models)):
|
|
q_id = q["question_id"]
|
|
m = models[i]
|
|
|
|
if q_id not in model_answers.get(m, {}):
|
|
# print(f"Warning: Skipping Q:{q_id} for model {m} due to missing answer.")
|
|
continue
|
|
a = model_answers[m][q_id]
|
|
|
|
current_ref = None
|
|
if ref_answers is not None and judge.ref_based:
|
|
if judge.model_name in ref_answers and q_id in ref_answers[judge.model_name]:
|
|
current_ref = ref_answers[judge.model_name][q_id]
|
|
# else:
|
|
# print(f"Warning: No ref answer for Q:{q_id} with judge {judge.model_name} though ref_based.")
|
|
|
|
matches.append(
|
|
MatchSingle(
|
|
dict(q), m, a, judge, ref_answer=current_ref, multi_turn=multi_turn
|
|
)
|
|
)
|
|
return matches
|
|
|
|
|
|
def make_judge_pairwise(judge_model, judge_prompts):
|
|
judges = {}
|
|
judges["default"] = Judge(judge_model, judge_prompts["pair-v2"])
|
|
judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True)
|
|
judges["default-mt"] = Judge(
|
|
judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True
|
|
)
|
|
judges["math-mt"] = Judge(
|
|
judge_model,
|
|
judge_prompts["pair-math-v1-multi-turn"],
|
|
ref_based=True,
|
|
multi_turn=True,
|
|
)
|
|
return judges
|
|
|
|
|
|
def make_judge_single(judge_model, judge_prompts):
|
|
judges = {}
|
|
judges["default"] = Judge(judge_model, judge_prompts["single-v1"])
|
|
judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True)
|
|
judges["default-mt"] = Judge(
|
|
judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
|
|
)
|
|
judges["math-mt"] = Judge(
|
|
judge_model,
|
|
judge_prompts["single-math-v1-multi-turn"],
|
|
ref_based=True,
|
|
multi_turn=True,
|
|
)
|
|
return judges
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--bench-name",
|
|
type=str,
|
|
default="mt_bench",
|
|
help="The name of the benchmark question set.",
|
|
)
|
|
parser.add_argument(
|
|
"--judge-file",
|
|
type=str,
|
|
default="data/judge_prompts.jsonl",
|
|
help="The file of judge prompts.",
|
|
)
|
|
parser.add_argument("--judge-model", type=str, default="gpt-4.1-mini")
|
|
parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
|
|
parser.add_argument(
|
|
"--mode",
|
|
type=str,
|
|
default="single",
|
|
choices=["pairwise-baseline", "pairwise-all", "single"],
|
|
help=(
|
|
"Evaluation mode. "
|
|
"`pairwise-baseline` runs pairwise comparision against a baseline. "
|
|
"`pairwise-all` runs pairwise comparision between all pairs. "
|
|
"`single` runs single answer grading."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--model-list",
|
|
type=str,
|
|
nargs="+",
|
|
default=None,
|
|
help="A list of models to be evaluated",
|
|
)
|
|
parser.add_argument(
|
|
"--parallel", type=int, default=1, help="The number of concurrent API calls."
|
|
)
|
|
parser.add_argument(
|
|
"--first-n", type=int, help="A debug option. Only run the first `n` judgments."
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
question_file = f"data/{args.bench_name}/question.jsonl"
|
|
answer_dir = f"data/{args.bench_name}/model_answer"
|
|
ref_answer_dir = f"data/{args.bench_name}/reference_answer"
|
|
|
|
# Load questions
|
|
questions = load_questions(question_file, None, None)
|
|
|
|
# Load answers
|
|
model_answers = load_model_answers(answer_dir)
|
|
# Handle cases where ref_answers might not exist or be relevant for all judges
|
|
try:
|
|
ref_answers = load_model_answers(ref_answer_dir)
|
|
except FileNotFoundError:
|
|
print(f"Reference answer directory not found: {ref_answer_dir}. Proceeding without reference answers.")
|
|
ref_answers = {} # Ensure ref_answers is a dict
|
|
|
|
# Load judge
|
|
judge_prompts = load_judge_prompts(args.judge_file)
|
|
|
|
if args.first_n:
|
|
questions = questions[: args.first_n]
|
|
|
|
if args.model_list is None:
|
|
models = get_model_list(answer_dir)
|
|
else:
|
|
models = args.model_list
|
|
|
|
# Define output_file and play_a_match_func based on mode
|
|
if args.mode == "single":
|
|
judges = make_judge_single(args.judge_model, judge_prompts)
|
|
play_a_match_func = play_a_match_single
|
|
output_file = (
|
|
f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
|
|
)
|
|
make_match_func = make_match_single
|
|
baseline_model = None
|
|
else: # pairwise modes
|
|
judges = make_judge_pairwise(args.judge_model, judge_prompts)
|
|
play_a_match_func = play_a_match_pair
|
|
# Ensure consistent output file naming for pairwise modes
|
|
pair_mode_suffix = "pair_all" if args.mode == "pairwise-all" else "pair_baseline"
|
|
output_file = (
|
|
f"data/{args.bench_name}/model_judgment/{args.judge_model}_{pair_mode_suffix}.jsonl"
|
|
)
|
|
if args.mode == "pairwise-all":
|
|
make_match_func = make_match_all_pairs
|
|
baseline_model = None
|
|
else: # pairwise-baseline
|
|
make_match_func = make_match
|
|
baseline_model = args.baseline_model
|
|
if baseline_model not in models: # If baseline is also in model-list, it's fine, handled by m1==m2 check
|
|
# if baseline model is not in general answer dir, it might cause issues later
|
|
if baseline_model not in model_answers:
|
|
print(f"Warning: Baseline model {baseline_model} answers not found in {answer_dir}. "
|
|
"Ensure it exists or it will be skipped in pairings.")
|
|
|
|
|
|
# Create directory for output file if it doesn't exist
|
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
|
print(f"Output file will be: {output_file}")
|
|
|
|
check_data(questions, model_answers, ref_answers, models, judges)
|
|
|
|
question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
|
|
question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]
|
|
|
|
# Make all potential matches
|
|
all_matches = []
|
|
# For default questions (non-reference based typically, unless judge is ref_based)
|
|
all_matches += make_match_func(
|
|
question_default, models, model_answers, judges["default"], baseline_model, ref_answers=ref_answers
|
|
)
|
|
# For math questions (reference based typically)
|
|
all_matches += make_match_func(
|
|
question_math,
|
|
models,
|
|
model_answers,
|
|
judges["math"],
|
|
baseline_model,
|
|
ref_answers=ref_answers, # Pass ref_answers here
|
|
)
|
|
# For multi-turn default questions
|
|
all_matches += make_match_func(
|
|
question_default,
|
|
models,
|
|
model_answers,
|
|
judges["default-mt"],
|
|
baseline_model,
|
|
ref_answers=ref_answers,
|
|
multi_turn=True,
|
|
)
|
|
# For multi-turn math questions
|
|
all_matches += make_match_func(
|
|
question_math,
|
|
models,
|
|
model_answers,
|
|
judges["math-mt"],
|
|
baseline_model,
|
|
ref_answers=ref_answers, # Pass ref_answers here
|
|
multi_turn=True,
|
|
)
|
|
|
|
# ---- MODIFICATION: Load existing judgments and filter matches ----
|
|
existing_judgments = load_existing_judgments(output_file, args.mode)
|
|
|
|
matches_to_play = []
|
|
skipped_count = 0
|
|
if not all_matches:
|
|
print("No matches were generated based on the inputs. Check model answers and question categories.")
|
|
else:
|
|
for match_item in all_matches:
|
|
qid = match_item.question["question_id"]
|
|
identifier = None
|
|
if args.mode == "single":
|
|
# MatchSingle has .model attribute
|
|
identifier = (qid, match_item.model)
|
|
else:
|
|
# MatchPair has .model1 and .model2 attributes
|
|
identifier = (qid, match_item.model1, match_item.model2)
|
|
|
|
if identifier and identifier in existing_judgments:
|
|
skipped_count += 1
|
|
continue
|
|
matches_to_play.append(match_item)
|
|
|
|
print(f"Total potential matches generated: {len(all_matches)}")
|
|
if existing_judgments: # Check if the set is not empty before printing its length
|
|
print(f"Found {len(existing_judgments)} existing judgment records.")
|
|
print(f"Number of new matches to play: {len(matches_to_play)}")
|
|
print(f"Number of matches skipped (already judged): {skipped_count}")
|
|
# ---- END MODIFICATION ----
|
|
|
|
match_stat = {}
|
|
match_stat["bench_name"] = args.bench_name
|
|
match_stat["mode"] = args.mode
|
|
match_stat["judge_model"] = args.judge_model # Changed from "judge"
|
|
match_stat["baseline"] = baseline_model
|
|
match_stat["model_list"] = models
|
|
match_stat["total_num_questions"] = len(questions)
|
|
match_stat["total_potential_matches"] = len(all_matches) # Original total
|
|
match_stat["num_matches_to_play"] = len(matches_to_play) # New stat
|
|
match_stat["num_already_judged"] = skipped_count # New stat
|
|
match_stat["output_path"] = output_file
|
|
|
|
# Show match stats and prompt enter to continue
|
|
print("\nStats:")
|
|
print(json.dumps(match_stat, indent=4))
|
|
|
|
if not matches_to_play:
|
|
print("\nNo new matches to play. Exiting.")
|
|
else:
|
|
print(f"\nProceeding to judge {len(matches_to_play)} new matches.")
|
|
|
|
# Play matches
|
|
if args.parallel == 1:
|
|
for match_item_loop in tqdm(matches_to_play, desc="Playing matches"):
|
|
play_a_match_func(match_item_loop, output_file=output_file)
|
|
else:
|
|
# Inner function for parallel execution
|
|
def play_a_match_wrapper(match_wrapper_item):
|
|
play_a_match_func(match_wrapper_item, output_file=output_file)
|
|
|
|
np.random.seed(0) # Seed for reproducible shuffling
|
|
shuffled_matches_to_play = list(matches_to_play) # Make a mutable copy for shuffle
|
|
np.random.shuffle(shuffled_matches_to_play)
|
|
|
|
with ThreadPoolExecutor(args.parallel) as executor:
|
|
list(tqdm( # Consume the iterator to ensure execution and show progress
|
|
executor.map(play_a_match_wrapper, shuffled_matches_to_play),
|
|
total=len(shuffled_matches_to_play),
|
|
desc="Playing matches (parallel)"
|
|
))
|
|
print(f"\nFinished judging. Results saved to {output_file}")
|