first commit

Dakota Nous 2025-04-29 12:10:10 -07:00
commit 621d00dd80
89 changed files with 15315 additions and 0 deletions

atroposlib/cli/dpo.py Normal file

@@ -0,0 +1,322 @@
import argparse
import asyncio
import os
import random
import aiohttp
import jsonlines
from tqdm.asyncio import tqdm # Import tqdm for async
from transformers import AutoTokenizer
def find_common_prefix(strings):
"""
Finds the longest common prefix among a list of strings.
Args:
strings: A list of strings.
Returns:
The longest common prefix string, or an empty string if the list is empty
or no common prefix exists.
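    Example (a quick doctest-style sanity check):
        >>> find_common_prefix(["interleave", "internal", "interval"])
        'inter'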
"""
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
async def register_to_api(group_size, max_token_len, api_url, num_steps):
"""
Registers this data grabber instance with the Atropos API.
This involves resetting any previous data on the server and then sending
configuration parameters for the current session.
Args:
group_size: The number of sequences processed per group by the API.
max_token_len: The maximum token length for sequences.
api_url: The base URL of the Atropos API server.
num_steps: The number of steps to run the API for.
"""
async with aiohttp.ClientSession() as session:
# Reset data on the API server before registering
async with session.get(f"{api_url}/reset_data") as response:
print(await response.text())
# Register this instance with its configuration
async with session.post(
f"{api_url}/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size * 8,
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": num_steps * 2, # For a bit of a buffer just in case
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch(api_url):
"""
Continuously polls the Atropos API until a batch of data is available.
Args:
api_url: The base URL of the Atropos API server.
Returns:
The batch data received from the API.
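    Note:
        Downstream callers treat the batch as a list of groups, each a dict
        with "tokens" (token-id lists), "messages" (raw chat structures), and
        "scores" (one float per sequence). A sketch inferred from that usage,
        not a guaranteed schema:
            [{"tokens": [[1, 2, ...]], "messages": [...], "scores": [0.7, -0.1]}]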
"""
while True:
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/batch") as response:
data = await response.json()
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1) # Wait before polling again
def grab_group_data(
tok,
datagroup,
save_messages,
save_n_pairs_per_group,
allow_negative_scores=False,
minimum_score_diff_max_min=0.0,
):
"""
Processes a single group of data received from the API.
This function sorts the sequences within the group by score, filters them
based on scoring criteria, and formats them for saving.
Args:
tok: The Hugging Face tokenizer instance.
datagroup: A dictionary representing a group of sequences and their scores.
save_messages: Boolean indicating whether to save raw message structures
or decoded text completions.
save_n_pairs_per_group: The maximum number of preference pairs to save from this group.
allow_negative_scores: Boolean indicating whether to allow sequences with
negative scores.
minimum_score_diff_max_min: The minimum score difference required to save a pair.
Returns:
A list of processed and filtered sequences from the group, ready to be
written to the output file.
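    Example:
        With scores [1.0, 0.2, -0.5] and the default settings, the 1.0-scoring
        completion becomes the "pos" of a pair and is matched with a randomly
        chosen lower-scoring completion as its "neg", subject to the
        minimum_score_diff_max_min gap.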
"""
if save_messages:
chats = datagroup["messages"]
    else:
        # Decode tokens into text, then split each chat into
        # (common prefix, completion); slicing avoids split() errors when the
        # prefix is empty or occurs more than once.
        chats = [tok.decode(chat) for chat in datagroup["tokens"]]
        prefix = find_common_prefix(chats)
        chats = [(prefix, chat[len(prefix):]) for chat in chats]
    # Sort chats by score, best-first; these supply the "pos" side of each pair
scores = datagroup["scores"]
sorted_chats = [
(
{"prefix": x[0], "pos": x[1], "score": score}
if not save_messages
else {"pos": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=True
)
]
    # Worst-first ordering of the same chats; these supply the "neg" candidates
    neg_sorted_chats = [
(
{"prefix": x[0], "completion": x[1], "score": score}
if not save_messages
else {"messages": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=False
)
]
neg_sorted_chats = neg_sorted_chats[:save_n_pairs_per_group]
if not allow_negative_scores:
sorted_chats = [x for x in sorted_chats if x["score"] > 0]
total_pairs = []
for i in range(min(save_n_pairs_per_group, len(sorted_chats))):
neg_candidates = [
x
for x in neg_sorted_chats
if x["score"] < sorted_chats[i]["score"] - minimum_score_diff_max_min
]
        if len(neg_candidates) > 0:
            # Pair with a random, sufficiently lower-scoring completion, then
            # remove it so each negative is used at most once per group.
            neg_candidate = random.choice(neg_candidates)
            neg_sorted_chats.remove(neg_candidate)
sorted_chats[i]["neg"] = (
neg_candidate["completion"]
if "completion" in neg_candidate
else neg_candidate["messages"]
)
total_pairs.append(sorted_chats[i])
return total_pairs
async def dpo_data_grabber(
filepath,
api_url,
group_size,
max_token_len,
tokenizer,
save_messages,
save_n_pairs_per_group,
num_seqs_to_save,
allow_negative_scores,
minimum_score_diff_max_min,
append_to_previous,
):
"""
Main asynchronous function to grab DPO data from the Atropos API.
It registers with the API, continuously fetches batches of data, processes
each batch, and writes the selected sequences to a JSONL file until the
desired number of sequences is saved.
Args:
filepath: Path to the output JSONL file.
api_url: Base URL of the Atropos API server.
group_size: Number of sequences processed per group by the API.
max_token_len: Maximum token length for sequences.
tokenizer: Hugging Face tokenizer model ID.
save_messages: Whether to save raw messages or decoded text.
save_n_pairs_per_group: Max preference pairs to save per group.
num_seqs_to_save: Total number of sequences to save.
allow_negative_scores: Whether to allow negative scores.
minimum_score_diff_max_min: Min score gap required between a pair's pos and its neg.
append_to_previous: Whether to append to an existing file or overwrite.
"""
tok = AutoTokenizer.from_pretrained(tokenizer)
total_count = 0
async def grab_batch(jsonl_writer: jsonlines.Writer):
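        """Fetches and processes one batch of data, returning the count."""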
data = await check_for_batch(api_url)
count = 0
for group in data:
for item in grab_group_data(
tok,
group,
save_messages,
save_n_pairs_per_group,
allow_negative_scores,
minimum_score_diff_max_min,
):
jsonl_writer.write(item)
count += 1
return count
    # Register with the API first; num_steps should cover at least the
    # number of sequences we plan to save.
    await register_to_api(
        group_size, max_token_len, api_url, num_steps=num_seqs_to_save
    )
    if os.path.exists(filepath) and not append_to_previous:
        raise ValueError(
            f"File '{filepath}' already exists and --append-to-previous is False."
        )
with open(filepath, "w" if not append_to_previous else "a") as f:
jsonl_writer = jsonlines.Writer(f)
with tqdm(total=num_seqs_to_save, desc="Grabbing DPO data", unit="seq") as pbar:
while total_count < num_seqs_to_save:
                batch_count = await grab_batch(jsonl_writer)
                # Update before incrementing so the bar never overshoots or
                # moves backwards on the final batch.
                pbar.update(min(batch_count, num_seqs_to_save - total_count))
                total_count += batch_count
def main():
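    """Parses command-line arguments and runs the DPO data grabber."""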
parser = argparse.ArgumentParser(
description="Grab SFT data from an Atropos API instance."
)
parser.add_argument(
"filepath",
type=str,
default="sft_data.jsonl",
help="Path to the output JSONL file for SFT data.",
)
parser.add_argument(
"--api-url",
type=str,
default="http://localhost:8000",
help="Base URL for the Atropos API server.",
)
parser.add_argument(
"--group-size",
type=int,
default=2,
help="Number of sequences processed per group by the API.",
)
parser.add_argument(
"--max-token-len",
type=int,
default=2048,
help="Maximum token length for sequences.",
)
parser.add_argument(
"--tokenizer",
type=str,
default="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
help="Hugging Face tokenizer model ID (used if --save-messages is not set).",
)
parser.add_argument(
"--save-messages",
action="store_true",
help="Save raw message structures instead of decoded text completions, if your environment supports it.",
)
parser.add_argument(
"--save-n-pairs-per-group",
type=int,
default=3,
help="Maximum number of paired sequences to save from each group.",
)
parser.add_argument(
"--num-seqs-to-save",
type=int,
default=100,
help="Total number of sequences to save before stopping.",
)
parser.add_argument(
"--allow-negative-scores",
action="store_true",
help="Allow sequences with negative scores to be saved.",
)
parser.add_argument(
"--minimum-score-diff-max-min",
type=float,
default=0.5,
help="Minimum score difference from the group minimum required to save a sequence.",
)
parser.add_argument(
"--append-to-previous",
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
args = parser.parse_args()
asyncio.run(
dpo_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_n_pairs_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
)
)
if __name__ == "__main__":
main()
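
# Example invocation (a sketch; assumes the package is importable and an
# Atropos API server is reachable at the default --api-url):
#
#   python -m atroposlib.cli.dpo dpo_pairs.jsonl \
#       --group-size 4 --save-n-pairs-per-group 2 --num-seqs-to-save 1000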


@@ -0,0 +1,66 @@
import argparse
import time
import requests
import wandb
def update_wandb(health_statuses):
wandb.log(health_statuses)
def run(api_addr, tp, node_num):
print(f"Starting up with {api_addr}, {tp}, {node_num}", flush=True)
while True:
try:
            data = requests.get(f"{api_addr}/wandb_info", timeout=10).json()
wandb_group = data["group"]
wandb_project = data["project"]
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
wandb_project = None
wandb_group = None
print("Waiting for init...")
if wandb_project is None:
time.sleep(1)
else:
wandb.init(
project=wandb_project, group=wandb_group, name=f"inf_node_{node_num}"
)
break
curr_step = 0
    health_statuses = {
        # One health slot per inference server; 8 // tp assumes 8 GPUs per node.
        f"server/server_health_{node_num}_{i}": 0.0 for i in range(8 // tp)
    }
while True:
data = requests.get(f"{api_addr}/status").json()
step = data["current_step"]
if step > curr_step:
wandb.log(health_statuses, step=step)
curr_step = step
time.sleep(60)
# Check on each server
for i in range(8 // tp):
try:
                health_status = requests.get(
                    f"http://localhost:{9000 + i}/health_generate", timeout=10
                ).status_code
health_statuses[f"server/server_heath_{node_num}_{i}"] = (
1 if health_status == 200 else 0
)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
health_statuses[f"server/server_heath_{node_num}_{i}"] = 0
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--api_addr", type=str, required=True)
parser.add_argument("--tp", type=int, required=True)
parser.add_argument("--node_num", type=int, required=True)
args = parser.parse_args()
run(args.api_addr, args.tp, args.node_num)
if __name__ == "__main__":
main()
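
# For example, with --tp 2 and --node_num 0 this polls four servers on ports
# 9000-9003 and logs metrics named server/server_health_0_0 through
# server/server_health_0_3 (naming per the loop above; 8 GPUs per node assumed).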


@@ -0,0 +1,9 @@
import uvicorn
def main():
uvicorn.run("atroposlib.api:app", host="0.0.0.0", port=8000, reload=True)
if __name__ == "__main__":
main()
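
# Quick smoke test once the server is up (a sketch; /status is the same
# endpoint the monitoring script above polls):
#
#   import requests
#   print(requests.get("http://localhost:8000/status").json())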

atroposlib/cli/sft.py Normal file

@@ -0,0 +1,318 @@
import argparse
import asyncio
import os
import aiohttp
import jsonlines
from tqdm.asyncio import tqdm # Import tqdm for async
from transformers import AutoTokenizer
def find_common_prefix(strings):
"""
Finds the longest common prefix among a list of strings.
Args:
strings: A list of strings.
Returns:
The longest common prefix string, or an empty string if the list is empty
or no common prefix exists.
"""
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
async def register_to_api(group_size, max_token_len, api_url, num_steps):
"""
Registers this data grabber instance with the Atropos API.
This involves resetting any previous data on the server and then sending
configuration parameters for the current session.
Args:
group_size: The number of sequences processed per group by the API.
max_token_len: The maximum token length for sequences.
api_url: The base URL of the Atropos API server.
num_steps: The number of steps to run the API for.
"""
async with aiohttp.ClientSession() as session:
# Reset data on the API server before registering
async with session.get(f"{api_url}/reset_data") as response:
print(await response.text())
# Register this instance with its configuration
async with session.post(
f"{api_url}/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size * 8,
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": num_steps * 2, # For a bit of a buffer just in case
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch(api_url):
"""
Continuously polls the Atropos API until a batch of data is available.
Args:
api_url: The base URL of the Atropos API server.
Returns:
The batch data received from the API.
"""
while True:
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/batch") as response:
data = await response.json()
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1) # Wait before polling again
def grab_group_data(
tok,
datagroup,
save_messages,
save_top_n_per_group,
allow_negative_scores=False,
minimum_score_diff_max_min=0.0,
):
"""
Processes a single group of data received from the API.
This function sorts the sequences within the group by score, filters them
based on scoring criteria, and formats them for saving.
Args:
tok: The Hugging Face tokenizer instance.
datagroup: A dictionary representing a group of sequences and their scores.
save_messages: Boolean indicating whether to save raw message structures
or decoded text completions.
save_top_n_per_group: The maximum number of sequences to save from this group.
allow_negative_scores: Boolean indicating whether to allow sequences with
negative scores.
minimum_score_diff_max_min: The minimum score difference from the group's
minimum score required to save a sequence.
Returns:
A list of processed and filtered sequences from the group, ready to be
written to the output file.
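    Example:
        With scores [1.0, 0.2, -0.5], allow_negative_scores=False, and
        minimum_score_diff_max_min=1.0: the -0.5 sequence is dropped for
        scoring negative, the 0.2 sequence for sitting only 0.7 above the
        group minimum, and only the 1.0-scoring sequence is returned.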
"""
if save_messages:
# Use raw message structures if specified
chats = datagroup["messages"]
else:
# Decode tokens into text and find common prefix/completion pairs
chats = [tok.decode(chat) for chat in datagroup["tokens"]]
prefix = find_common_prefix(chats)
        # Split each chat into (prefix, completion); slicing avoids split()
        # errors when the prefix is empty or occurs more than once.
        chats = [(prefix, chat[len(prefix):]) for chat in chats]
scores = datagroup["scores"]
# Sort chats by score in descending order
sorted_chats = [
(
{"prefix": x[0], "completion": x[1], "score": score}
if not save_messages
else {"messages": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=True
)
]
# Apply filtering based on score criteria
if not allow_negative_scores:
sorted_chats = [x for x in sorted_chats if x["score"] > 0]
if minimum_score_diff_max_min > 0:
# Ensure the score is sufficiently higher than the minimum score in the group
min_score = min(scores) if scores else 0 # Handle empty scores list
sorted_chats = [
x
for x in sorted_chats
if x["score"] - min_score > minimum_score_diff_max_min
]
# Return only the top N sequences
return sorted_chats[:save_top_n_per_group]
async def sft_data_grabber(
filepath,
api_url,
group_size,
max_token_len,
tokenizer,
save_messages,
save_top_n_per_group,
num_seqs_to_save,
allow_negative_scores,
minimum_score_diff_max_min,
append_to_previous,
):
"""
Main asynchronous function to grab SFT data from the Atropos API.
It registers with the API, continuously fetches batches of data, processes
each batch, and writes the selected sequences to a JSONL file until the
desired number of sequences is saved.
Args:
filepath: Path to the output JSONL file.
api_url: Base URL of the Atropos API server.
group_size: Number of sequences processed per group by the API.
max_token_len: Maximum token length for sequences.
tokenizer: Hugging Face tokenizer model ID.
save_messages: Whether to save raw messages or decoded text.
save_top_n_per_group: Max sequences to save per group.
num_seqs_to_save: Total number of sequences to save.
allow_negative_scores: Whether to allow negative scores.
minimum_score_diff_max_min: Min score difference from group minimum.
append_to_previous: Whether to append to an existing file or overwrite.
"""
tok = AutoTokenizer.from_pretrained(tokenizer)
total_count = 0
async def grab_batch(jsonl_writer: jsonlines.Writer):
"""Fetches and processes one batch of data, returning the count."""
data = await check_for_batch(api_url)
count = 0
for group in data:
for item in grab_group_data(
tok,
group,
save_messages,
save_top_n_per_group,
allow_negative_scores,
minimum_score_diff_max_min,
):
jsonl_writer.write(item)
count += 1
return count
    # Register with the API first; num_steps should cover at least the
    # number of sequences we plan to save.
    await register_to_api(
        group_size, max_token_len, api_url, num_steps=num_seqs_to_save
    )
# Check for file existence before opening
if os.path.exists(filepath) and not append_to_previous:
raise ValueError(
f"File '{filepath}' already exists and --append-to-previous is False."
)
# Open the file in write or append mode
file_mode = "a" if append_to_previous and os.path.exists(filepath) else "w"
with open(filepath, file_mode) as f:
jsonl_writer = jsonlines.Writer(f)
# Use tqdm for progress bar
with tqdm(total=num_seqs_to_save, desc="Grabbing SFT data", unit="seq") as pbar:
while total_count < num_seqs_to_save:
                batch_count = await grab_batch(jsonl_writer)
                # Update before incrementing so the bar never overshoots or
                # moves backwards on the final batch.
                pbar.update(min(batch_count, num_seqs_to_save - total_count))
                total_count += batch_count
def main():
"""Parses command-line arguments and runs the SFT data grabber."""
parser = argparse.ArgumentParser(
description="Grab SFT data from an Atropos API instance."
)
parser.add_argument(
"filepath",
type=str,
default="sft_data.jsonl",
help="Path to the output JSONL file for SFT data.",
)
parser.add_argument(
"--api-url",
type=str,
default="http://localhost:8000",
help="Base URL for the Atropos API server.",
)
parser.add_argument(
"--group-size",
type=int,
default=2,
help="Number of sequences processed per group by the API.",
)
parser.add_argument(
"--max-token-len",
type=int,
default=2048,
help="Maximum token length for sequences.",
)
parser.add_argument(
"--tokenizer",
type=str,
default="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
help="Hugging Face tokenizer model ID (used if --save-messages is not set).",
)
parser.add_argument(
"--save-messages",
action="store_true",
help="Save raw message structures instead of decoded text completions, if your environment supports it.",
)
parser.add_argument(
"--save-top-n-per-group",
type=int,
default=3,
help="Maximum number of highest-scoring sequences to save from each group.",
)
parser.add_argument(
"--num-seqs-to-save",
type=int,
default=100,
help="Total number of sequences to save before stopping.",
)
parser.add_argument(
"--allow-negative-scores",
action="store_true",
help="Allow sequences with negative scores to be saved.",
)
parser.add_argument(
"--minimum-score-diff-max-min",
type=float,
default=0.0,
help="Minimum score difference from the group minimum required to save a sequence.",
)
parser.add_argument(
"--append-to-previous",
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
args = parser.parse_args()
# Run the main async function
asyncio.run(
sft_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_top_n_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
)
)
if __name__ == "__main__":
main()
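
# Example invocation (a sketch; assumes the package is importable and an
# Atropos API server is reachable at the default --api-url):
#
#   python -m atroposlib.cli.sft sft_data.jsonl \
#       --save-top-n-per-group 1 --minimum-score-diff-max-min 0.5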

atroposlib/cli/view_run.py Normal file

@@ -0,0 +1,105 @@
import argparse
import asyncio
import aiohttp
import gradio as gr
from transformers import AutoTokenizer
def find_common_prefix(strings):
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
async def register_to_api(group_size, max_token_len):
async with aiohttp.ClientSession() as session:
async with session.get("http://localhost:8000/reset_data") as response:
print(await response.text())
print(group_size)
async with session.post(
"http://localhost:8000/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size
* 8, # * 8 just in case you want to just sample from a large group
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": 69,
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch():
while True:
async with aiohttp.ClientSession() as session:
async with session.get("http://localhost:8000/batch") as response:
data = await response.json()
print(data)
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1)
async def build_interface(group_size, max_token_len, tokenizer, port):
async def grab_batch():
tok = AutoTokenizer.from_pretrained(tokenizer)
data = await check_for_batch()
print(data)
chats = [tok.decode(chat) for chat in data[0]["tokens"]]
        # Find the common prefix and slice it off each chat; slicing avoids
        # split() errors when the prefix is empty or occurs more than once.
        prefix = find_common_prefix(chats)
        return (
            (prefix,)
            + tuple(chat[len(prefix):] for chat in chats[:group_size])
+ tuple(data[0]["scores"][:group_size])
)
with gr.Blocks() as demo:
prefix_blk = gr.Textbox(label="Prefix")
with gr.Row():
score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)]
with gr.Row():
outputs_blks = [
gr.Textbox(label=f"Output_{i+1}") for i in range(group_size)
]
with gr.Row():
grab_next = gr.Button(value="Grab Next Batch")
grab_next.click(
fn=grab_batch,
outputs=[prefix_blk] + outputs_blks + score_blks,
api_name="get_batch",
)
await register_to_api(group_size, max_token_len)
demo.launch(server_port=port, share=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=9001)
parser.add_argument("--group-size", type=int, default=2)
parser.add_argument("--max-token-len", type=int, default=2048)
parser.add_argument(
"--tokenizer", type=str, default="NousResearch/DeepHermes-3-Llama-3-8B-Preview"
)
args = parser.parse_args()
asyncio.run(
build_interface(args.group_size, args.max_token_len, args.tokenizer, args.port)
)
if __name__ == "__main__":
main()
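
# Example invocation (a sketch; assumes an Atropos API server on
# localhost:8000; the Gradio UI is served on --port):
#
#   python -m atroposlib.cli.view_run --group-size 4 --port 9001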