# atropos/atroposlib/cli/dpo.py
# (source listing: 324 lines, 11 KiB, Python)

import argparse
import asyncio
import os
import random
import aiohttp
import jsonlines
from tqdm.asyncio import tqdm # Import tqdm for async
from transformers import AutoTokenizer
from atroposlib.utils.io import parse_http_response
def find_common_prefix(strings):
    """
    Return the longest common prefix shared by every string in *strings*.

    Args:
        strings: A list of strings.

    Returns:
        The longest common prefix, or "" when the list is empty or the
        strings share no leading characters.
    """
    if not strings:
        return ""
    # The answer can never be longer than the shortest input, so scan its
    # characters column by column and stop at the first mismatch.
    shortest = min(strings, key=len)
    for idx, ch in enumerate(shortest):
        if any(other[idx] != ch for other in strings):
            return shortest[:idx]
    return shortest
async def register_to_api(group_size, max_token_len, api_url, num_steps=100):
    """
    Register this data grabber instance with the Atropos API.

    Resets any previous data on the server, then sends the configuration
    parameters for the current session.

    Args:
        group_size: The number of sequences processed per group by the API.
        max_token_len: The maximum token length for sequences.
        api_url: The base URL of the Atropos API server.
        num_steps: The number of steps to run the API for. Defaults to 100
            so existing callers that omit it (dpo_data_grabber previously
            passed only three arguments, which raised TypeError) keep
            working. TODO(review): confirm a sensible default against the
            server's expectations.
    """
    async with aiohttp.ClientSession() as session:
        # Reset data on the API server before registering
        async with session.get(f"{api_url}/reset_data") as response:
            print(await response.text())
        # Register this instance with its configuration
        async with session.post(
            f"{api_url}/register",
            json={
                "wandb_group": "test",
                "wandb_project": "test",
                "batch_size": group_size * 8,
                "max_token_len": max_token_len,
                "checkpoint_dir": "checkpoints",
                "save_checkpoint_interval": 10,
                "starting_step": 0,
                "num_steps": num_steps * 2,  # For a bit of a buffer just in case
            },
        ) as response:
            print("output of register is")
            print(await response.text())
async def check_for_batch(api_url):
    """
    Poll the Atropos API until a batch of data becomes available.

    Args:
        api_url: The base URL of the Atropos API server.

    Returns:
        The batch data received from the API.
    """
    while True:
        # A fresh session per poll keeps each request independent.
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{api_url}/batch") as response:
                payload = await parse_http_response(response)
        batch = payload["batch"]
        if batch is not None:
            return batch
        # Nothing ready yet; back off briefly before the next poll.
        await asyncio.sleep(1)
def grab_group_data(
    tok,
    datagroup,
    save_messages,
    save_n_pairs_per_group,
    allow_negative_scores=False,
    minimum_score_diff_max_min=0.0,
):
    """
    Process a single group of data received from the API.

    Sorts the sequences within the group by score, filters them based on
    scoring criteria, and pairs each kept positive with a lower-scoring
    negative for DPO training.

    Args:
        tok: The Hugging Face tokenizer instance (unused when
            save_messages is True).
        datagroup: A dictionary representing a group of sequences and their
            scores ("scores" plus either "messages" or "tokens").
        save_messages: Whether to save raw message structures instead of
            decoded text completions.
        save_n_pairs_per_group: The maximum number of pairs to save from
            this group.
        allow_negative_scores: Whether to allow positives with
            non-positive scores.
        minimum_score_diff_max_min: The minimum score gap required between
            a positive and its paired negative.

    Returns:
        A list of processed and filtered pair dicts from the group, ready
        to be written to the output file.
    """
    if save_messages:
        # Raw message structures are not strings, so no prefix stripping is
        # possible (the original code crashed here trying to treat message
        # lists as strings).
        chats = datagroup["messages"]
    else:
        decoded = [tok.decode(chat) for chat in datagroup["tokens"]]
        # Strip the shared prompt prefix. Slicing by length (rather than
        # str.split) is safe when the prefix is empty and when the prefix
        # text recurs inside a completion.
        prefix = os.path.commonprefix(decoded)
        chats = [(prefix, chat[len(prefix):]) for chat in decoded]
    scores = datagroup["scores"]
    # Positives: best score first. Key names differ between the two output
    # formats ("pos" tuple fields vs raw "pos" messages) by design.
    sorted_chats = [
        (
            {"prefix": x[0], "pos": x[1], "score": score}
            if not save_messages
            else {"pos": x, "score": score}
        )
        for score, x in sorted(
            zip(scores, chats), key=lambda pair: pair[0], reverse=True
        )
    ]
    # Negatives: worst score first, capped at save_n_pairs_per_group.
    neg_sorted_chats = [
        (
            {"prefix": x[0], "completion": x[1], "score": score}
            if not save_messages
            else {"messages": x, "score": score}
        )
        for score, x in sorted(
            zip(scores, chats), key=lambda pair: pair[0], reverse=False
        )
    ]
    neg_sorted_chats = neg_sorted_chats[:save_n_pairs_per_group]
    if not allow_negative_scores:
        sorted_chats = [x for x in sorted_chats if x["score"] > 0]
    total_pairs = []
    for i in range(min(save_n_pairs_per_group, len(sorted_chats))):
        # Only negatives sufficiently below this positive's score qualify.
        neg_candidates = [
            x
            for x in neg_sorted_chats
            if x["score"] < sorted_chats[i]["score"] - minimum_score_diff_max_min
        ]
        if neg_candidates:
            # The loop only runs when save_n_pairs_per_group > 0, so the
            # original fallback to the single worst negative was dead code;
            # a random qualifying negative is always chosen.
            neg_candidate = random.choice(neg_candidates)
            # Each negative is used at most once per group.
            neg_sorted_chats.remove(neg_candidate)
            sorted_chats[i]["neg"] = (
                neg_candidate["completion"]
                if "completion" in neg_candidate
                else neg_candidate["messages"]
            )
            total_pairs.append(sorted_chats[i])
    return total_pairs
async def dpo_data_grabber(
    filepath,
    api_url,
    group_size,
    max_token_len,
    tokenizer,
    save_messages,
    save_n_pairs_per_group,
    num_seqs_to_save,
    allow_negative_scores,
    minimum_score_diff_max_min,
    append_to_previous,
):
    """
    Main asynchronous function to grab DPO data from the Atropos API.

    Registers with the API, continuously fetches batches of data, processes
    each batch, and writes the selected pairs to a JSONL file until the
    desired number of sequences is saved.

    Args:
        filepath: Path to the output JSONL file.
        api_url: Base URL of the Atropos API server.
        group_size: Number of sequences processed per group by the API.
        max_token_len: Maximum token length for sequences.
        tokenizer: Hugging Face tokenizer model ID.
        save_messages: Whether to save raw messages or decoded text.
        save_n_pairs_per_group: Max pairs to save per group.
        num_seqs_to_save: Total number of sequences to save.
        allow_negative_scores: Whether to allow negative scores.
        minimum_score_diff_max_min: Min score gap between paired pos/neg.
        append_to_previous: Whether to append to an existing file.

    Raises:
        ValueError: If *filepath* exists and append_to_previous is False.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer)
    total_count = 0

    async def grab_batch(jsonl_writer: jsonlines.Writer):
        # Fetch one batch, write every selected pair, return the pair count.
        data = await check_for_batch(api_url)
        count = 0
        for group in data:
            for item in grab_group_data(
                tok,
                group,
                save_messages,
                save_n_pairs_per_group,
                allow_negative_scores,
                minimum_score_diff_max_min,
            ):
                jsonl_writer.write(item)
                count += 1
        return count

    # The server's batch_size is group_size * 8 (see register_to_api), so
    # estimate the steps needed to yield num_seqs_to_save sequences.
    # NOTE(review): original code omitted num_steps entirely, which raised
    # TypeError; confirm this estimate matches the server's expectation.
    batch_size = group_size * 8
    num_steps = max(1, (num_seqs_to_save + batch_size - 1) // batch_size)
    await register_to_api(group_size, max_token_len, api_url, num_steps)
    if os.path.exists(filepath) and not append_to_previous:
        raise ValueError("File already exists and append_to_previous is False.")
    mode = "a" if append_to_previous else "w"
    # Context-manage the Writer too, so buffered output is flushed/closed.
    with open(filepath, mode) as f, jsonlines.Writer(f) as jsonl_writer:
        with tqdm(total=num_seqs_to_save, desc="Grabbing DPO data", unit="seq") as pbar:
            while total_count < num_seqs_to_save:
                # Capture the remaining quota BEFORE incrementing, so the
                # final pbar.update is capped at the bar total instead of
                # going negative when a batch overshoots.
                remaining = num_seqs_to_save - total_count
                batch_count = await grab_batch(jsonl_writer)
                total_count += batch_count
                pbar.update(min(batch_count, remaining))
def main():
    """Parse CLI arguments and run the DPO data grabber."""
    # NOTE: description/defaults previously said "SFT" — copy-paste residue
    # from the sibling sft.py CLI; this tool produces DPO pairs.
    parser = argparse.ArgumentParser(
        description="Grab DPO data from an Atropos API instance."
    )
    parser.add_argument(
        "filepath",
        type=str,
        nargs="?",  # optional positional: without nargs="?" the default was dead
        default="dpo_data.jsonl",
        help="Path to the output JSONL file for DPO data.",
    )
    parser.add_argument(
        "--api-url",
        type=str,
        default="http://localhost:8000",
        help="Base URL for the Atropos API server.",
    )
    parser.add_argument(
        "--group-size",
        type=int,
        default=2,
        help="Number of sequences processed per group by the API.",
    )
    parser.add_argument(
        "--max-token-len",
        type=int,
        default=2048,
        help="Maximum token length for sequences.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        help="Hugging Face tokenizer model ID (used if --save-messages is not set).",
    )
    parser.add_argument(
        "--save-messages",
        action="store_true",
        help="Save raw message structures instead of decoded text completions, if your environment supports it.",
    )
    parser.add_argument(
        "--save-n-pairs-per-group",
        type=int,
        default=3,
        help="Maximum number of paired sequences to save from each group.",
    )
    parser.add_argument(
        "--num-seqs-to-save",
        type=int,
        default=100,
        help="Total number of sequences to save before stopping.",
    )
    parser.add_argument(
        "--allow-negative-scores",
        action="store_true",
        help="Allow sequences with negative scores to be saved.",
    )
    parser.add_argument(
        "--minimum-score-diff-max-min",
        type=float,
        default=0.5,
        help="Minimum score difference from the group minimum required to save a sequence.",
    )
    parser.add_argument(
        "--append-to-previous",
        action="store_true",
        help="Append to the previous file instead of overwriting it.",
    )
    args = parser.parse_args()
    asyncio.run(
        dpo_data_grabber(
            args.filepath,
            args.api_url,
            args.group_size,
            args.max_token_len,
            args.tokenizer,
            args.save_messages,
            args.save_n_pairs_per_group,
            args.num_seqs_to_save,
            args.allow_negative_scores,
            args.minimum_score_diff_max_min,
            args.append_to_previous,
        )
    )
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()