# BLEUBERI/eval/arena-hard-v2.0/gen_answer.py
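"""Generate model answers for the Arena-Hard v2.0 benchmark.

For each model in the config's model_list, load the benchmark questions from
data/<bench_name>/question.jsonl, send every prompt to the model's registered
completion backend (a remote API or a local engine), and append the answers,
with token-length and markdown metadata, to
data/<bench_name>/model_answer/<model>.jsonl. Questions a model has already
answered are skipped.
"""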

import argparse
import concurrent.futures
import json
import os
import re
import time

import shortuuid
import tiktoken
import tqdm

from utils.add_markdown_info import count_markdown_elements, remove_pattern
from utils.completion import (
    load_questions,
    load_model_answers,
    make_config,
    get_endpoint,
    registered_api_completion,
    registered_engine_completion,
    reorg_answer_file,
    API_ERROR_OUTPUT,
)
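
# Note: registered_api_completion / registered_engine_completion (imported
# above from utils/completion.py) map an "api_type" string to a completion
# callable; which extra kwargs each callable accepts depends on the backend
# it wraps.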

def get_answer(question: dict, model: str, answer_file: str, settings: dict):
    # build the chat messages for this question
    messages = []
    if "sys_prompt" in settings:
        messages.append({"role": "system", "content": settings["sys_prompt"]})
    messages.append({"role": "user", "content": question["prompt"]})

    # retrieve the api completion function from the register
    api_completion_func = registered_api_completion[settings["api_type"]]

    # build arguments for the api completion call
    kwargs = settings | {
        "api_dict": get_endpoint(settings["endpoints"]),
        "messages": messages,
    }
    output = api_completion_func(**kwargs)

    # API_ERROR_OUTPUT is a sentinel for a failed request; write nothing so
    # the question can be retried on a later run
    if output is API_ERROR_OUTPUT:
        return

    messages.append({"role": "assistant", "content": output})

    # dump the answer record
    ans = {
        "uid": question["uid"],
        "ans_id": shortuuid.uuid(),
        "model": model,
        "messages": messages,
        "tstamp": time.time(),
    }

    # token count plus markdown statistics (code blocks stripped first)
    encoding = tiktoken.encoding_for_model("gpt-4o")
    metadata = {
        "token_len": len(encoding.encode(output["answer"], disallowed_special=()))
    }
    ans["metadata"] = metadata | count_markdown_elements(
        remove_pattern(output["answer"], re.compile("```([^`]*)```")),
        suffix="",
    )

    os.makedirs(os.path.dirname(answer_file), exist_ok=True)
    with open(answer_file, "a", encoding="utf-8") as fout:
        fout.write(json.dumps(ans, ensure_ascii=False) + "\n")
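
# Sketch of one appended JSONL record (field values are illustrative only):
# {"uid": "...", "ans_id": "...", "model": "<model>", "tstamp": 1749068203.0,
#  "messages": [{"role": "user", ...}, {"role": "assistant", ...}],
#  "metadata": {"token_len": 512, ...}}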

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-file", type=str, default="config/gen_answer_config.yaml"
    )
    parser.add_argument(
        "--endpoint-file", type=str, default="config/api_config.yaml"
    )
    args = parser.parse_args()

    config = make_config(args.config_file)
    endpoints = make_config(args.endpoint_file)
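
    # Expected config shapes, inferred from the lookups below (not a full
    # schema):
    #   gen_answer_config.yaml  ->  bench_name: ..., model_list: [<model>, ...]
    #   api_config.yaml         ->  <model>: {api_type: ..., endpoints: ...,
    #                                optionally sys_prompt/parallel/local_engine}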

    existing_answer = load_model_answers(
        os.path.join("data", config["bench_name"], "model_answer")
    )
    print(config)

    for model in config["model_list"]:
        assert model in endpoints, f"no endpoint settings found for {model}"
        endpoint_settings = endpoints[model]

        question_file = os.path.join("data", config["bench_name"], "question.jsonl")
        questions = load_questions(question_file)

        answer_file = os.path.join(
            "data", config["bench_name"], "model_answer", f"{model}.jsonl"
        )
        print(f"Output to {answer_file}")

        parallel = endpoint_settings.get("parallel", 1)

        if endpoint_settings.get("local_engine"):
            # local engines consume the whole question batch in one call and
            # write to answer_file directly
            local_completion_func = registered_engine_completion[
                endpoint_settings["api_type"]
            ]
            kwargs = endpoint_settings | {
                "answer_file": answer_file,
                "batch_context": questions,
            }
            local_completion_func(**kwargs)
            reorg_answer_file(answer_file)
        else:
            # API backends: one request per question, up to `parallel` at a time
            with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as executor:
                futures = []
                count = 0
                for question in questions:
                    # skip questions this model has already answered
                    if model in existing_answer and question["uid"] in existing_answer[model]:
                        count += 1
                        continue
                    future = executor.submit(
                        get_answer,
                        question,
                        model,
                        answer_file,
                        endpoint_settings,
                    )
                    futures.append(future)

                if count > 0:
                    print(f"Skipped {count} existing answers")

                for future in tqdm.tqdm(
                    concurrent.futures.as_completed(futures), total=len(futures)
                ):
                    future.result()

            reorg_answer_file(answer_file)
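
# Example invocation, assuming the working directory holds config/ and data/:
#   python gen_answer.py \
#       --config-file config/gen_answer_config.yaml \
#       --endpoint-file config/api_config.yaml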