mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-22 16:49:17 +00:00
85 lines
2.6 KiB
Python
import os
|
|
import json
|
|
import pandas as pd
|
|
import ast
|
|
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import rcParams
|
|
|
|
import argparse
|
|
import seaborn as sns
|
|
from tqdm import tqdm
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
if __name__ == "__main__":
    # Deduplicate chat conversations: tag rows whose user-prompt text occurs
    # more often than a percentile cutoff, keeping only a random sample of
    # each high-frequency prompt, and write the tagged records to JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    # Accepted for CLI compatibility; not used in this script.
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()

    output_dir = args.output_dir
    input_file = args.input_file

    with open(input_file) as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: concatenate the content of every user turn in
    # "conversation_a" (truncated to 10k chars) into a single string that
    # serves as the deduplication key for the row.
    all_convs_new = []
    for row in data:
        conv = "".join(
            f"{turns['content']}\n"
            for turns in row["conversation_a"]
            if turns["role"] == "user"
        )
        row["post_process_conv"] = conv[:10000]
        all_convs_new.append(row)

    df = pd.DataFrame(all_convs_new)
    print("Number of conversations: ", len(df))

    prompt_counts = df["post_process_conv"].value_counts()
    # Select the top 20 most frequent prompts (printed for inspection only).
    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    # Determine the percentile count; prompts occurring more often than this
    # cutoff are treated as high frequency and downsampled below.
    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile*100} percentile count: {percentile_cutoff}")

    # prompts that are more common than the percentile cutoff
    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # Initialize one INDEPENDENT tag dict per row. Note: broadcasting a single
    # dict into many slots (arr[idx] = {...} with an index array) would make
    # every targeted row alias the same mutable object, so per-row assignment
    # is done in explicit loops below.
    dedup_tags = np.array(
        [{"high_freq": False, "sampled": True} for _ in range(len(df))]
    )
    high_freq_groups = df.groupby("post_process_conv")
    for prompt in tqdm(high_frequency_prompts):
        df_high_freq = high_freq_groups.get_group(prompt)
        # Keep a deterministic random sample of size `percentile_cutoff`
        # for each high-frequency prompt; the rest are marked unsampled.
        sampled_indices = df_high_freq.sample(
            n=int(percentile_cutoff), random_state=42
        ).index
        for idx in df_high_freq.index:
            dedup_tags[idx] = {"high_freq": True, "sampled": False}
        for idx in sampled_indices:
            dedup_tags[idx] = {"high_freq": True, "sampled": True}

    df["dedup_tag"] = dedup_tags

    # drop intermediate columns (post_process_conv) used only as the dedup key
    df = df.drop(columns=["post_process_conv"])

    df.to_json(
        os.path.join(output_dir, "dedup.json"),
        orient="records",
        indent=4,
        force_ascii=False,
    )
|