"""
|
|
Clean chatbot arena battle log.
|
|
|
|
Usage:
|
|
python3 clean_battle_data.py --mode conv_release
|
|
"""
|
|
import argparse
import datetime
import json
import os
from pytz import timezone
import time

from tqdm import tqdm
from multiprocessing import Pool
import tiktoken
from collections import Counter
import shortuuid

from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS
from fastchat.utils import detect_language


VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
IDENTITY_WORDS = [
    "vicuna",
    "lmsys",
    "koala",
    "uc berkeley",
    "open assistant",
    "laion",
    "chatglm",
    "chatgpt",
    "gpt-4",
    "openai",
    "anthropic",
    "claude",
    "bard",
    "palm",
    "lamda",
    "google",
    "gemini",
    "llama",
    "qianwan",
    "qwen",
    "alibaba",
    "mistral",
    "zhipu",
    "KEG lab",
    "01.AI",
    "AI2",
    "Tülu",
    "Tulu",
    "deepseek",
    "hermes",
    "cohere",
    "DBRX",
    "databricks",
]

ERROR_WORDS = [
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
    "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
    "API REQUEST ERROR. Please increase the number of max tokens.",
    "**API REQUEST ERROR** Reason: The response was blocked.",
    "**API REQUEST ERROR**",
]

UNFINISHED_WORDS = [
    "▌",
    '<span class="cursor">',
]
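
# The conversation text is lowercased before matching in process_data, so
# normalize these filter lists to lowercase once at import time to keep the
# leak/error substring checks case-insensitive.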
for i in range(len(IDENTITY_WORDS)):
    IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()

for i in range(len(ERROR_WORDS)):
    ERROR_WORDS[i] = ERROR_WORDS[i].lower()
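

# Public model names sometimes arrive wrapped in an <h3> header from the web
# UI; strip it. A hypothetical example:
# "<h3>Model A: vicuna-13b</h3>\n" -> "vicuna-13b".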
def remove_html(raw):
    if isinstance(raw, str) and raw.startswith("<h3>"):
        return raw[raw.find(": ") + 2 : -len("</h3>\n")]
    return raw
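

# Convert fastchat's [role, message] pairs into OpenAI-style chat dicts.
# Roles are assigned purely by position (even -> "user", odd -> "assistant");
# the stored role tag x[0] is ignored. E.g. (hypothetical input)
# [["USER", "hi"], ["ASSISTANT", "hello"]] ->
# [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}].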
def to_openai_format(messages):
    roles = ["user", "assistant"]
    ret = []
    for i, x in enumerate(messages):
        ret.append({"role": roles[i % 2], "content": x[1]})
    return ret
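

# Normalize legacy / renamed model names. The cutoff 1687849200 corresponds
# to 2023-06-27 00:00 US/Pacific, roughly when the arena moved from OpenAI's
# -0314 snapshots to the -0613 snapshots.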
def replace_model_name(old_name, tstamp):
    replace_dict = {
        "bard": "palm-2",
        "claude-v1": "claude-1",
        "claude-instant-v1": "claude-instant-1",
        "oasst-sft-1-pythia-12b": "oasst-pythia-12b",
        "claude-2": "claude-2.0",
        "StripedHyena-Nous-7B": "stripedhyena-nous-7b",
        "gpt-4-turbo": "gpt-4-1106-preview",
        "gpt-4-0125-assistants-api": "gpt-4-turbo-browsing",
    }
    if old_name in ["gpt-4", "gpt-3.5-turbo"]:
        if tstamp > 1687849200:
            return old_name + "-0613"
        else:
            return old_name + "-0314"
    if old_name in replace_dict:
        return replace_dict[old_name]
    return old_name
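

# Read one JSONL log file, keeping only vote events. FileNotFoundError is
# retried a few times, presumably because log files can appear or rotate
# while the script is running.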
def read_file(filename):
    data = []
    for retry in range(5):
        try:
            for l in open(filename):
                row = json.loads(l)
                if row["type"] in VOTES:
                    data.append(row)
            break
        except FileNotFoundError:
            time.sleep(2)
    return data


def read_file_parallel(log_files, num_threads=16):
    data_all = []
    with Pool(num_threads) as p:
        ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
        for ret in ret_all:
            data_all.extend(ret)
    return data_all
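

# Turn raw vote rows into battle records. Rows are dropped (and tallied in
# count_dict) when model names are missing or inconsistent, when a model's
# identity leaks into the conversation, and on API errors, unfinished
# streaming output, empty messages, banned IPs, or excluded models.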
def process_data(
    data,
    exclude_model_names,
    sanitize_ip,
    ban_ip_list,
):
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
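    # encoding_for_model("gpt-3.5-turbo") resolves to the cl100k_base
    # tokenizer, so the num_tokens counts below are only approximate for
    # models with different tokenizers.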
    convert_type = {
        "leftvote": "model_a",
        "rightvote": "model_b",
        "tievote": "tie",
        "bothbad_vote": "tie (bothbad)",
    }

    all_ips = dict()

    count_dict = {
        "anony": 0,
        "invalid": 0,
        "leaked_identity": 0,
        "banned": 0,
        "error": 0,
        "unfinished": 0,
        "none_msg": 0,
        "exclude_model": 0,
    }
    count_leak = {}

    battles = []
    for row in data:
        flag_anony = False
        flag_leaked_identity = False
        flag_error = False
        flag_unfinished = False
        flag_none_msg = False

        if row["models"][0] is None or row["models"][1] is None:
            continue

        # Resolve model names
        models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
        if "model_name" in row["states"][0]:
            models_hidden = [
                row["states"][0]["model_name"],
                row["states"][1]["model_name"],
            ]
            if models_hidden[0] is None:
                models_hidden = models_public
        else:
            models_hidden = models_public

        if (models_public[0] == "" and models_public[1] != "") or (
            models_public[1] == "" and models_public[0] != ""
        ):
            count_dict["invalid"] += 1
            continue

        if models_public[0] == "" or models_public[0] == "Model A":
            flag_anony = True
            models = models_hidden
        else:
            flag_anony = False
            models = models_public
            if (
                models_hidden[0] not in models_public[0]
                or models_hidden[1] not in models_public[1]
            ):
                count_dict["invalid"] += 1
                continue

        # Detect language
        state = row["states"][0]
        if state["offset"] >= len(state["messages"]):
            count_dict["invalid"] += 1
            continue
        lang_code = detect_language(state["messages"][state["offset"]][1])

        # Drop conversations if the model names are leaked
        messages = ""
        for i in range(2):
            state = row["states"][i]
            for _, (role, msg) in enumerate(state["messages"][state["offset"] :]):
                if msg:
                    messages += msg.lower()
                else:
                    flag_none_msg = True

        for word in IDENTITY_WORDS:
            if word in messages:
                if word not in count_leak:
                    count_leak[word] = 0
                count_leak[word] += 1
                flag_leaked_identity = True
                break

        for word in ERROR_WORDS:
            if word in messages:
                flag_error = True
                break

        for word in UNFINISHED_WORDS:
            if word in messages:
                flag_unfinished = True
                break

        if flag_none_msg:
            count_dict["none_msg"] += 1
            continue
        if flag_leaked_identity:
            count_dict["leaked_identity"] += 1
            continue
        if flag_error:
            count_dict["error"] += 1
            continue
        if flag_unfinished:
            count_dict["unfinished"] += 1
            continue

        # Replace bard with palm
        models = [replace_model_name(m, row["tstamp"]) for m in models]
        # Exclude certain models
        if exclude_model_names and any(x in exclude_model_names for x in models):
            count_dict["exclude_model"] += 1
            continue

        question_id = row["states"][0]["conv_id"]
        conversation_a = to_openai_format(
            row["states"][0]["messages"][row["states"][0]["offset"] :]
        )
        conversation_b = to_openai_format(
            row["states"][1]["messages"][row["states"][1]["offset"] :]
        )

        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": shortuuid.uuid()}
        all_ips[ip]["count"] += 1
        if sanitize_ip:
            user_id = f"{all_ips[ip]['sanitized_id']}"
        else:
            user_id = f"{all_ips[ip]['ip']}"

        if ban_ip_list is not None and ip in ban_ip_list:
            count_dict["banned"] += 1
            continue

        if flag_anony:
            count_dict["anony"] += 1
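
        # allowed_special="all" lets tiktoken encode text that happens to
        # contain special tokens such as "<|endoftext|>" instead of raising
        # an error on them.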
        for conv in conversation_a:
            conv["num_tokens"] = len(
                encoding.encode(conv["content"], allowed_special="all")
            )
        for conv in conversation_b:
            conv["num_tokens"] = len(
                encoding.encode(conv["content"], allowed_special="all")
            )

        # Save the results
        battles.append(
            dict(
                question_id=question_id,
                model_a=models[0],
                model_b=models[1],
                winner=convert_type[row["type"]],
                judge=f"arena_user_{user_id}",
                conversation_a=conversation_a,
                conversation_b=conversation_b,
                turn=len(conversation_a) // 2,
                anony=flag_anony,
                language=lang_code,
                tstamp=row["tstamp"],
            )
        )
    return battles, count_dict, count_leak, all_ips
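

# Split the vote rows into roughly 100 chunks, run process_data on them in
# parallel, and merge the per-chunk battle lists, counters, and IP tables.
# (anony_only is accepted for call compatibility but is not used here.)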
def clean_battle_data(
    log_files,
    exclude_model_names,
    ban_ip_list=None,
    sanitize_ip=False,
    anony_only=False,
    num_threads=16,
):
    data = read_file_parallel(log_files, num_threads=num_threads)

    battles = []
    count_dict = {}
    count_leak = {}
    all_ips = {}
    with Pool(num_threads) as p:
        # split data into chunks
        chunk_size = len(data) // min(100, len(data))
        data_chunks = [
            data[i : i + chunk_size] for i in range(0, len(data), chunk_size)
        ]

        args_list = [
            (data_chunk, exclude_model_names, sanitize_ip, ban_ip_list)
            for data_chunk in data_chunks
        ]
        ret_all = list(tqdm(p.starmap(process_data, args_list), total=len(data_chunks)))

        for ret in ret_all:
            sub_battles, sub_count_dict, sub_count_leak, sub_all_ips = ret
            battles.extend(sub_battles)
            count_dict = dict(Counter(count_dict) + Counter(sub_count_dict))
            count_leak = dict(Counter(count_leak) + Counter(sub_count_leak))
            for ip in sub_all_ips:
                if ip not in all_ips:
                    all_ips[ip] = sub_all_ips[ip]
                else:
                    all_ips[ip]["count"] += sub_all_ips[ip]["count"]
    battles.sort(key=lambda x: x["tstamp"])
    last_updated_tstamp = battles[-1]["tstamp"]

    last_updated_datetime = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d %H:%M:%S %Z")

    print(f"#votes: {len(data)}")
    print(count_dict)
    print(f"#battles: {len(battles)}, #anony: {count_dict['anony']}")
    print(f"last-updated: {last_updated_datetime}")
    print(f"leaked_identity: {count_leak}")

    if ban_ip_list is not None:
        for ban_ip in ban_ip_list:
            if ban_ip in all_ips:
                del all_ips[ban_ip]
    print("Top 30 IPs:")
    print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
    return battles


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument(
        "--mode", type=str, choices=["simple", "conv_release"], default="simple"
    )
    parser.add_argument("--exclude-model-names", type=str, nargs="+")
    parser.add_argument("--ban-ip-file", type=str)
    parser.add_argument("--sanitize-ip", action="store_true", default=False)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None

    battles = clean_battle_data(
        log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
    )
    last_updated_tstamp = battles[-1]["tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")
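
    # "simple" strips conversations and keeps only vote metadata;
    # "conv_release" keeps full conversations but only for anonymous battles.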
    if args.mode == "simple":
        for x in battles:
            for key in [
                "conversation_a",
                "conversation_b",
                "question_id",
            ]:
                del x[key]
        print("Samples:")
        for i in range(4):
            print(battles[i])
        output = f"clean_battle_{cutoff_date}.json"
    elif args.mode == "conv_release":
        new_battles = []
        for x in battles:
            if not x["anony"]:
                continue
            new_battles.append(x)
        battles = new_battles
        output = f"clean_battle_conv_{cutoff_date}.json"

    with open(output, "w", encoding="utf-8", errors="replace") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Wrote cleaned data to {output}")