mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-30 17:40:49 +00:00
141 lines
No EOL
4.3 KiB
Python
141 lines
No EOL
4.3 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import time
|
|
import concurrent.futures
|
|
|
|
import tiktoken
|
|
import shortuuid
|
|
import tqdm
|
|
|
|
from utils.add_markdown_info import count_markdown_elements, remove_pattern
|
|
from utils.completion import (
|
|
load_questions,
|
|
load_model_answers,
|
|
make_config,
|
|
get_endpoint,
|
|
registered_api_completion,
|
|
registered_engine_completion,
|
|
reorg_answer_file,
|
|
API_ERROR_OUTPUT,
|
|
)
|
|
|
|
|
|
def get_answer(
|
|
question: dict, answer_file: str, settings: dict
|
|
):
|
|
# build messages
|
|
messages = []
|
|
if "sys_prompt" in settings:
|
|
messages.append({"role": "system", "content": settings["sys_prompt"]})
|
|
|
|
messages.append({"role": "user", "content": question["prompt"]})
|
|
|
|
# retrieve the api completion function from register
|
|
api_completion_func = registered_api_completion[settings["api_type"]]
|
|
|
|
# build arguments for api completions
|
|
kwargs = settings | {
|
|
"api_dict": get_endpoint(settings["endpoints"]),
|
|
"messages": messages,
|
|
}
|
|
output = api_completion_func(**kwargs)
|
|
if output is API_ERROR_OUTPUT:
|
|
return
|
|
|
|
messages.append({"role": "assistant", "content": output})
|
|
|
|
# Dump answers
|
|
ans = {
|
|
"uid": question["uid"],
|
|
"ans_id": shortuuid.uuid(),
|
|
"model": model,
|
|
"messages": messages,
|
|
"tstamp": time.time(),
|
|
}
|
|
|
|
encoding = tiktoken.encoding_for_model("gpt-4o")
|
|
metadata = {
|
|
"token_len": len(encoding.encode(output['answer'], disallowed_special=()))
|
|
}
|
|
ans["metadata"] = metadata | count_markdown_elements(
|
|
remove_pattern(
|
|
output['answer'],
|
|
re.compile("```([^`]*)```")
|
|
),
|
|
suffix="",
|
|
)
|
|
|
|
os.makedirs(os.path.dirname(answer_file), exist_ok=True)
|
|
with open(answer_file, "a", encoding="utf-8") as fout:
|
|
fout.write(json.dumps(ans, ensure_ascii=False) + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--config-file", type=str, default="config/gen_answer_config.yaml"
|
|
)
|
|
parser.add_argument(
|
|
"--endpoint-file", type=str, default="config/api_config.yaml"
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
config = make_config(args.config_file)
|
|
endpoints = make_config(args.endpoint_file)
|
|
|
|
existing_answer = load_model_answers(os.path.join("data", config["bench_name"], "model_answer"))
|
|
|
|
print(config)
|
|
|
|
for model in config["model_list"]:
|
|
assert model in endpoints
|
|
endpoint_settings = endpoints[model]
|
|
|
|
question_file = os.path.join("data", config["bench_name"], "question.jsonl")
|
|
questions = load_questions(question_file)
|
|
|
|
answer_file = os.path.join("data", config["bench_name"], "model_answer", f"{model}.jsonl")
|
|
print(f"Output to {answer_file}")
|
|
|
|
if "parallel" in endpoint_settings:
|
|
parallel = endpoint_settings["parallel"]
|
|
else:
|
|
parallel = 1
|
|
|
|
if 'local_engine' in endpoint_settings and endpoint_settings['local_engine']:
|
|
local_completion_func = registered_engine_completion[endpoint_settings['api_type']]
|
|
|
|
kwargs = endpoint_settings | {
|
|
"answer_file": answer_file,
|
|
"batch_context": questions,
|
|
}
|
|
local_completion_func(**kwargs)
|
|
|
|
reorg_answer_file(answer_file)
|
|
|
|
else:
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as executor:
|
|
futures = []
|
|
count = 0
|
|
for index, question in enumerate(questions):
|
|
if model in existing_answer and question["uid"] in existing_answer[model]:
|
|
count += 1
|
|
continue
|
|
future = executor.submit(
|
|
get_answer,
|
|
question,
|
|
answer_file,
|
|
endpoint_settings,
|
|
)
|
|
futures.append(future)
|
|
if count > 0:
|
|
print(f"{count} number of existing answers")
|
|
for future in tqdm.tqdm(
|
|
concurrent.futures.as_completed(futures), total=len(futures)
|
|
):
|
|
future.result()
|
|
|
|
reorg_answer_file(answer_file)
|
|
|