"""Generate answers with GPT-4
|
|
|
|
Usage:
|
|
python3 gen_api_answer.py --model gpt-3.5-turbo
|
|
"""

import argparse
import json
import os
import time
import concurrent.futures

import openai
import shortuuid
import tqdm

from fastchat.llm_judge.common import (
    load_questions,
    temperature_config,
    chat_completion_openai,
    chat_completion_anthropic,
    chat_completion_palm,
)
from fastchat.llm_judge.gen_model_answer import reorg_answer_file
from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST


def get_answer(
    question: dict, model: str, num_choices: int, max_tokens: int, answer_file: str
):
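    """Generate `num_choices` answers to one question and append them to `answer_file`."""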
    # `args` is read from module scope (parsed in __main__). A forced
    # temperature and a per-question "required_temperature" are mutually
    # exclusive.
    assert not (
        args.force_temperature is not None and "required_temperature" in question
    )
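
    # Temperature resolution order: --force-temperature, the question's own
    # "required_temperature", the per-category default in temperature_config,
    # then a 0.7 fallback.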
    if args.force_temperature is not None:
        temperature = args.force_temperature
    elif "required_temperature" in question:
        temperature = question["required_temperature"]
    elif question["category"] in temperature_config:
        temperature = temperature_config[question["category"]]
    else:
        temperature = 0.7

    choices = []
    chat_state = None  # for palm-2 model
    for i in range(num_choices):
        conv = get_conversation_template(model)

        turns = []
        for j in range(len(question["turns"])):
            conv.append_message(conv.roles[0], question["turns"][j])
            conv.append_message(conv.roles[1], None)

            if model in ANTHROPIC_MODEL_LIST:
                output = chat_completion_anthropic(model, conv, temperature, max_tokens)
            elif model == "palm-2-chat-bison-001":
                chat_state, output = chat_completion_palm(
                    chat_state, model, conv, temperature, max_tokens
                )
            else:
                output = chat_completion_openai(model, conv, temperature, max_tokens)

            conv.update_last_message(output)
            turns.append(output)

        choices.append({"index": i, "turns": turns})

    # Dump answers
    ans = {
        "question_id": question["question_id"],
        "answer_id": shortuuid.uuid(),
        "model_id": model,
        "choices": choices,
        "tstamp": time.time(),
    }
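
    # Each worker appends one JSON line per answer, shaped like (illustrative
    # values): {"question_id": 81, "answer_id": "<shortuuid>", "model_id":
    # "gpt-3.5-turbo", "choices": [{"index": 0, "turns": ["..."]}], "tstamp": 1700000000.0}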
    os.makedirs(os.path.dirname(answer_file), exist_ok=True)
    with open(answer_file, "a") as fout:
        fout.write(json.dumps(ans) + "\n")
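    # Multiple workers append to the same file concurrently, so line order is
    # arbitrary; reorg_answer_file in __main__ reorganizes the file after all
    # futures finish.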


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--force-temperature", type=float, help="Forcibly set a sampling temperature."
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The start index of the questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of the questions."
    )
    parser.add_argument(
        "--parallel", type=int, default=1, help="The number of concurrent API calls."
    )
    parser.add_argument("--openai-api-base", type=str, default=None)
    args = parser.parse_args()

    if args.openai_api_base is not None:
        openai.api_base = args.openai_api_base
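        # e.g. an OpenAI-compatible server at http://localhost:8000/v1
        # (illustrative URL); note that module-level `openai.api_base` is the
        # legacy pre-1.0 openai SDK configuration style.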

    question_file = f"data/{args.bench_name}/question.jsonl"
    questions = load_questions(question_file, args.question_begin, args.question_end)

    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"data/{args.bench_name}/model_answer/{args.model}.jsonl"
    print(f"Output to {answer_file}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor:
        futures = []
        for question in questions:
            future = executor.submit(
                get_answer,
                question,
                args.model,
                args.num_choices,
                args.max_tokens,
                answer_file,
            )
            futures.append(future)

        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()
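            # result() re-raises any exception from the worker thread, so API
            # errors surface here instead of being silently dropped.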

    reorg_answer_file(answer_file)