"""
|
|
|
|
Usage:
|
|
python3 topic_clustering.py --in arena.json --english-only --min-length 32
|
|
python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536
|
|
"""
|
|
import argparse
import json
import pickle
import string

import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.cluster import KMeans, AgglomerativeClustering
import torch
from tqdm import tqdm
from openai import OpenAI

from fastchat.utils import detect_language


def remove_punctuation(input_string):
    # Make a translator object that strips all punctuation
    translator = str.maketrans("", "", string.punctuation)
    return input_string.translate(translator)


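# read_texts expects the input JSON to be a list of records and pulls the
# prompt text from whichever of these fields is present: "text",
# "conversation_a" / "conversation" (lists of {"role", "content"} turns, of
# which only user turns are kept), or "turns" (a plain list of strings).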
def read_texts(input_file, min_length, max_length, english_only):
    visited = set()
    texts = []

    with open(input_file, "r") as f:
        lines = json.load(f)

    for l in tqdm(lines):
        if "text" in l:
            line_texts = [l["text"]]
        elif "conversation_a" in l:
            line_texts = [
                x["content"] for x in l["conversation_a"] if x["role"] == "user"
            ]
        elif "conversation" in l:
            line_texts = [
                x["content"] for x in l["conversation"] if x["role"] == "user"
            ]
        elif "turns" in l:
            line_texts = l["turns"]
        else:
            # Skip records in an unrecognized format
            continue

        for text in line_texts:
            text = text.strip()

            # Filter by language
            if english_only:
                lang = detect_language(text)
                if lang != "English":
                    continue

            # Filter prompts that are too short or too long
            if min_length and len(text) < min_length:
                continue
            if max_length and len(text) > max_length:
                continue

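            # Worked example of the de-duplication key built below
            # (hypothetical input):
            #   "How are you?" -> remove punctuation -> "How are you"
            #   -> lowercase, sort, join -> "arehowyou"
            # Any reordering of the same words collapses to the same key.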
            # De-duplication on the bag of lowercased words
            words = sorted([x.lower() for x in remove_punctuation(text).split(" ")])
            words = "".join(words)
            if words in visited:
                continue

            visited.add(words)
            texts.append(text)
    return np.array(texts)


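# Example call (hypothetical; the local path assumes a CUDA device, the
# OpenAI path assumes OPENAI_API_KEY is set):
#   embeddings = get_embeddings(texts, "all-mpnet-base-v2", batch_size=256)
#   # -> tensor of shape (len(texts), 768) with L2-normalized rows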
def get_embeddings(texts, model_name, batch_size):
    if model_name == "text-embedding-ada-002":
        # Embed via the OpenAI API
        client = OpenAI()
        texts = texts.tolist()

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size)):
            batch = texts[i : i + batch_size]
            responses = client.embeddings.create(input=batch, model=model_name).data
            embeddings.extend([data.embedding for data in responses])
        embeddings = torch.tensor(embeddings)
    else:
        # Embed locally with a Sentence-Transformers model
        model = SentenceTransformer(model_name)
        embeddings = model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            device="cuda",
            convert_to_tensor=True,
        )

    # L2-normalize so cosine similarity reduces to a dot product
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu()


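# Both run_k_means and run_agg_cluster relabel clusters by descending size,
# e.g. raw labels with counts {0: 5, 1: 20, 2: 10} are remapped
# 1 -> 0, 2 -> 1, 0 -> 2, so cluster 0 is always the largest.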
def run_k_means(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
    clustering_model.fit(embeddings.numpy())
    centers = torch.from_numpy(clustering_model.cluster_centers_)
    labels = torch.from_numpy(clustering_model.labels_)

    # Sort labels by cluster size, largest first
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    new_centers = torch.empty_like(centers)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i
        new_centers[i] = centers[c]
    return new_centers, new_labels


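# AgglomerativeClustering has no notion of cluster centers, so (as with
# HDBSCAN below) a center is approximated post hoc as the mean of each
# cluster's normalized embeddings.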
def run_agg_cluster(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
    clustering_model.fit(embeddings)
    labels = torch.from_numpy(clustering_model.labels_)

    # Sort labels by cluster size, largest first
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute centers as the mean embedding of each cluster
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels


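# Unlike the two methods above, HDBSCAN chooses the number of clusters itself
# (subject to min_cluster_size) and assigns label -1 to noise points, so the
# --num-clusters argument is ignored on this path.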
def run_hdbscan_cluster(embeddings):
    import hdbscan

    np.random.seed(42)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
    labels = torch.from_numpy(clusterer.fit_predict(embeddings))

    # Sort labels by cluster size, largest first (noise points, labeled -1,
    # are carried through the relabeling like any other cluster)
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute centers as the mean embedding of each cluster
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels


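# get_topk_indices returns a (num_clusters, topk) LongTensor of row indices
# into `texts`; row i lists the members of cluster i closest to its center,
# most similar first.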
def get_topk_indices(centers, labels, embeddings, topk):
    indices = []
    arange = torch.arange(len(labels))
    counts = torch.unique(labels, return_counts=True)[1]
    # Cap topk at the smallest cluster size so every cluster contributes
    # the same number of examples
    topk = min(topk, counts.min().item())
    for i in range(len(centers)):
        tmp_indices = labels == i
        tmp_arange = arange[tmp_indices]
        tmp_embeddings = embeddings[tmp_indices]

        # Rank cluster members by cosine similarity to the cluster center
        scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
        sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
        indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0))
    return torch.cat(indices)


def print_topk(texts, labels, topk_indices, show_cut_off):
    ret = ""
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()

        ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n"
        for idx in topk_indices[k]:
            ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n"
        ret += "=" * 40 + "\n\n"

    return ret


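# Each cluster_info entry is a (num_samples, topk_prompts, random_prompts)
# tuple; the random prompts are drawn uniformly from the whole corpus,
# presumably as a baseline for judging how topical the top-k prompts are.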
def get_cluster_info(texts, labels, topk_indices):
    np.random.seed(42)

    cluster_info = []
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()
        topk_prompts = []
        for idx in topk_indices[k]:
            topk_prompts.append(texts[idx])
        random_prompts = []
        # Sample the same number of random prompts as top-k prompts
        for _ in range(len(topk_indices[k])):
            random_prompts.append(np.random.choice(texts))
        cluster_info.append((num_samples, topk_prompts, random_prompts))

    return cluster_info


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--model", type=str, default="all-mpnet-base-v2")
    # Alternatives: "all-MiniLM-L12-v2", "multi-qa-distilbert-cos-v1"
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--min-length", type=int)
    parser.add_argument("--max-length", type=int)
    parser.add_argument("--english-only", action="store_true")
    parser.add_argument("--num-clusters", type=int, default=20)
    parser.add_argument(
        "--cluster-alg",
        type=str,
        choices=["kmeans", "aggcls", "HDBSCAN"],
        default="kmeans",
    )
    parser.add_argument("--show-top-k", type=int, default=200)
    parser.add_argument("--show-cut-off", type=int, default=512)
    parser.add_argument("--save-embeddings", action="store_true")
    parser.add_argument("--embeddings-file", type=str, default=None)
    args = parser.parse_args()

    num_clusters = args.num_clusters
    show_top_k = args.show_top_k
    show_cut_off = args.show_cut_off

    texts = read_texts(
        args.input_file, args.min_length, args.max_length, args.english_only
    )
    print(f"#text: {len(texts)}")

    if args.embeddings_file is None:
        embeddings = get_embeddings(texts, args.model, args.batch_size)
        if args.save_embeddings:
            # Cache embeddings to disk to save time (and API cost) on reruns
            torch.save(embeddings, "embeddings.pt")
    else:
        embeddings = torch.load(args.embeddings_file)
    print(f"embeddings shape: {embeddings.shape}")

    if args.cluster_alg == "kmeans":
        centers, labels = run_k_means(embeddings, num_clusters)
    elif args.cluster_alg == "aggcls":
        centers, labels = run_agg_cluster(embeddings, num_clusters)
    elif args.cluster_alg == "HDBSCAN":
        centers, labels = run_hdbscan_cluster(embeddings)
    else:
        raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}")

    topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k)
    topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
    # HDBSCAN picks its own number of clusters, so recompute it from the result
    num_clusters = len(centers)

    # Dump results
    filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
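    # Three artifacts share this prefix:
    #   {prefix}_topk.txt     readable top-k prompts per cluster
    #   {prefix}_all.jsonl    every prompt with its cluster id and center similarity
    #   {prefix}_cluster.pkl  pickled cluster_info tuples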
    print(topk_str)
    with open(filename_prefix + "_topk.txt", "w") as fout:
        fout.write(topk_str)

    with open(filename_prefix + "_all.jsonl", "w") as fout:
        for i in range(len(centers)):
            tmp_indices = labels == i
            tmp_embeddings = embeddings[tmp_indices]
            tmp_texts = texts[tmp_indices]

            scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
            sorted_indices = torch.flip(torch.argsort(scores), dims=[0])

            for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]):
                obj = {"cluster": i, "text": text, "sim": score.item()}
                fout.write(json.dumps(obj, ensure_ascii=False) + "\n")

    cluster_info = get_cluster_info(texts, labels, topk_indices)
    with open(filename_prefix + "_cluster.pkl", "wb") as fout:
        pickle.dump(cluster_info, fout)