atropos/environments/sft_loader_server.py
2025-05-07 15:19:26 -05:00

266 lines
10 KiB
Python

import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import Item
class DatasetFormat(Enum):
    """Supported row layouts for the HF dataset consumed by the SFT loader."""

    # One text column holding the entire training example.
    COMPLETION = "completion"
    # A prefix column (excluded from the loss) followed by a completion column.
    PREFIXED_COMPLETION = "prefix_completion"
    # OpenAI chat format: a list of {"role": ..., "content": ...} messages.
    OAI_FORMAT = "oai_format"
    # ShareGPT format: a list of {"from": ..., "value": ...} messages.
    SHAREGPT_FORMAT = "sharegpt_format"
class SFTConfig(BaseEnvConfig):
    """Configuration for SFTEnv: which HF dataset to stream, how its rows are
    laid out, how loss masks are built, and how much data to push per step."""

    # HF hub id passed straight to datasets.load_dataset (train split is used).
    dataset_name: str = Field(
        default="gsm8k",
        description="The name of the HF dataset to use for training",
    )
    # Row layout; selects the tokenize/mask branch in SFTEnv.format_item.
    dataset_format: DatasetFormat = Field(
        default=DatasetFormat.COMPLETION,
        description="The format of the dataset to use for training",
    )
    # Column holding the text (COMPLETION/PREFIXED_COMPLETION) or the message
    # list (OAI_FORMAT/SHAREGPT_FORMAT).
    dataset_column_name: str = Field(
        default="text",
        description="The name of the column to use for training",
    )
    # Only read for PREFIXED_COMPLETION; the prefix is masked out of the loss.
    prefix_column_name: Optional[str] = Field(
        default=None,
        description="The name of the column to use for the prefix, if applicable",
    )
    # Chat formats only: unmask just the assistant turns.
    mask_everything_but_assistant_answer: bool = Field(
        default=True,
        description="Whether to mask everything but the assistant answer, if applicable",
    )
    # Chat formats only: unmask just the final turn; checked before the
    # assistant-answer flag in format_item, so it takes precedence when both set.
    mask_everything_but_last_step: bool = Field(
        default=False,
        description="Whether to mask everything but the last step, if applicable",
    )
    # Cap on items pushed per training step; -1 means no cap.
    max_sft_per_step: int = Field(
        default=-1,
        description="The maximum number of SFTs to do per step, if -1 just sends all the data",
    )
    # Throttle: only push data on steps divisible by this value.
    add_every_n_steps: int = Field(
        default=1,
        description="Only add SFT data every n steps",
    )
class SFTEnv(BaseEnv):
    """Environment that serves supervised fine-tuning (SFT) data from a HF
    dataset to the training server.

    Each dataset row becomes a single-item ScoredDataGroup with score and
    advantage 1 and an {"sft": True} override, telling the trainer to apply a
    plain SFT loss to it.  Mask entries of -100 mark tokens excluded from the
    loss.
    """

    # Registry name for this environment.
    name = "sft_loader"

    def __init__(
        self,
        config: SFTConfig,
        server_configs: List[OpenaiConfig],
        slurm=True,
        testing=False,
    ):
        """Initialize bookkeeping state; dataset loading happens in setup()."""
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []
        # Cursor into the shuffled dataset; advanced in env_step_checks, not
        # in get_next_item.
        self.idx = 0
        self.last_step = -1

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        """Return the default (env config, server configs) pair for the CLI."""
        env_config = SFTConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            ensure_scores_are_not_same=False,
            wandb_name="sft",
            dataset_name="AlignmentLab-AI/open-instruct-sharegpt",
            dataset_format=DatasetFormat.SHAREGPT_FORMAT,
            dataset_column_name="conversations",
            inference_weight=-1,  # presumably disables inference traffic for this env — confirm in BaseEnvConfig
            max_sft_per_step=8,
        )
        server_configs = [
            OpenaiConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log dataset-cursor progress; assumes setup() has set self.train."""
        if wandb_metrics is None:
            wandb_metrics = {}
        wandb_metrics["train/idx"] = self.idx
        # Integer count of full passes over the dataset so far.
        wandb_metrics["train/epoch"] = self.idx // len(self.train)
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the train split of the configured dataset and shuffle it with
        a fixed seed for reproducibility."""
        self.train = load_dataset(self.config.dataset_name, split="train").shuffle(
            seed=42
        )
        self.iter = 0

    async def get_next_item(self) -> Item:
        """Return the dataset row at the current cursor, wrapping around.

        The cursor is intentionally NOT advanced here; env_step_checks
        increments self.idx once per item it actually sends.
        """
        next_item = self.train[self.idx % len(self.train)]
        # self.idx += 1
        return next_item

    async def evaluate(self, *args, **kwargs):
        """No-op: this environment has nothing to evaluate."""
        pass

    async def add_train_workers(self):
        """No-op: data is pushed from env_step_checks rather than by workers."""
        pass

    async def format_item(self, item: Item):
        """Tokenize one dataset row into a single-item ScoredDataGroup.

        Dispatches on config.dataset_format to build `tokens` and a parallel
        `masks` list where -100 marks positions excluded from the loss.
        """
        group = ScoredDataGroup()
        group["tokens"] = list()
        group["masks"] = list()
        group["messages"] = list()
        group["scores"] = [1]
        group["advantages"] = [1]
        # GRPO/RLOO should natively work if they don't apply mean/std to a 1 item group and have 0 KL term,
        # but in case you are using something different, e.g. PPO,
        # we'll send an override to the server to tell it to do a SFT step.
        group["group_overrides"] = {"sft": True}
        group["overrides"] = [{"sft": True}]
        if self.config.dataset_format == DatasetFormat.COMPLETION:
            # Plain text column: every token contributes to the loss.
            tokens = self.tokenizer.encode(item[self.config.dataset_column_name])
            group["tokens"].append(tokens)
            group["messages"].append(item[self.config.dataset_column_name])
            group["masks"].append(tokens)
        elif self.config.dataset_format == DatasetFormat.PREFIXED_COMPLETION:
            # Prefix column is masked out; only the completion gets loss.
            tokens = self.tokenizer.encode(
                item[self.config.prefix_column_name]
                + item[self.config.dataset_column_name]
            )
            prefix_tokens = self.tokenizer.encode(item[self.config.prefix_column_name])
            # Drop a trailing EOS the tokenizer may append to the bare prefix
            # so the prefix length lines up with the combined encoding.
            if prefix_tokens[-1] == self.tokenizer.eos_token_id:
                prefix_tokens = prefix_tokens[:-1]
            group["tokens"].append(tokens)
            group["messages"].append(
                item[self.config.prefix_column_name]
                + item[self.config.dataset_column_name]
            )
            # NOTE(review): assumes encode(prefix + completion) begins with the
            # same ids as encode(prefix); tokenizers that merge across the
            # boundary would misalign this mask — verify for the tokenizer used.
            group["masks"].append(
                [-100 for _ in range(len(prefix_tokens))] + tokens[len(prefix_tokens) :]
            )
        else:
            # Chat-style data: normalize to OpenAI message format first.
            if self.config.dataset_format == DatasetFormat.OAI_FORMAT:
                messages = item[self.config.dataset_column_name]
            elif self.config.dataset_format == DatasetFormat.SHAREGPT_FORMAT:
                messages = item[self.config.dataset_column_name]
                # reformat to oai format
                role_map = {"human": "user", "gpt": "assistant", "system": "system"}
                messages = [
                    {"role": role_map[message["from"]], "content": message["value"]}
                    for message in messages
                ]
            group["messages"].append(messages)
            tokens = self.tokenizer.apply_chat_template(messages)
            group["tokens"].append(tokens)
            if self.config.mask_everything_but_last_step:
                # Mask everything rendered before the final message; only the
                # last turn contributes to the loss.
                masks = [
                    -100
                    for _ in self.tokenizer.apply_chat_template(
                        messages[:-1], add_generation_prompt=True
                    )
                ]
                masks.extend(tokens[len(masks) :])
                group["masks"].append(masks)
            elif self.config.mask_everything_but_assistant_answer:
                # find the assistant answer in the chats
                # For each assistant turn: mask the tokens of everything before
                # it, then unmask that turn's own tokens.
                masks = []
                for i, msg in enumerate(messages):
                    if msg["role"] == "assistant":
                        prefix_tokens = self.tokenizer.apply_chat_template(
                            messages[:i], add_generation_prompt=True
                        )
                        masks.extend(
                            [-100 for _ in range(len(prefix_tokens[len(masks) :]))]
                        )
                        curr_assist_tokens = self.tokenizer.apply_chat_template(
                            messages[: i + 1]
                        )
                        masks.extend(curr_assist_tokens[len(masks) :])
                # NOTE(review): if the conversation does not end with an
                # assistant turn, masks ends up shorter than tokens — confirm
                # the trainer tolerates (or pads) that mismatch.
                group["masks"].append(masks)
        return group

    async def env_step_checks(self):
        """Push SFT data to the trainer whenever the step counter advances.

        curr_step, status_dict, max_batches_offpolicy and handle_send_to_api
        are provided by BaseEnv — semantics assumed from usage here; confirm
        against the base class.
        """
        # Check if we need to run an eval or log...
        if self.curr_step != self.status_dict["current_step"]:
            next_step = self.status_dict["current_step"]
            if next_step % self.config.add_every_n_steps == 0:
                # Top the queue up to max_batches_offpolicy * batch_size items,
                # capped by max_sft_per_step when that is set.
                to_send = (
                    self.config.max_batches_offpolicy * self.config.batch_size
                ) - (self.status_dict["queue_size"])
                if self.config.max_sft_per_step != -1:
                    to_send = min(to_send, self.config.max_sft_per_step)
                self.items_sent_this_step = to_send
                if to_send > 0:
                    formatted_items = list()
                    for _ in range(to_send):
                        item = await self.get_next_item()
                        # Advance the dataset cursor here, once per sent item.
                        self.idx += 1
                        formatted_items.append(await self.format_item(item))
                    await asyncio.gather(
                        *[
                            self.handle_send_to_api(formatted_item, None)
                            for formatted_item in formatted_items
                        ]
                    )
        await super().env_step_checks()
async def checkout_formatting():
    """Debug helper: build an SFTEnv and print how a hand-written ShareGPT
    conversation is tokenized and masked by format_item().

    Not wired into the CLI; run manually with asyncio.run(checkout_formatting()).
    Requires the dataset/tokenizer downloads and so needs network access.
    """
    env = SFTEnv(
        config=SFTConfig(
            dataset_name="AlignmentLab-AI/open-instruct-sharegpt",
            dataset_format=DatasetFormat.SHAREGPT_FORMAT,
            dataset_column_name="conversations",
        ),
        server_configs=[
            OpenaiConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            )
        ],
        slurm=False,
        testing=True,
    )
    await env.setup()
    # Use a fixed conversation so the printed output is reproducible regardless
    # of dataset contents.  (A dataset row previously fetched here was dead
    # code: it was immediately overwritten by this literal.)
    item = {
        "conversations": [
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
            {"from": "human", "value": "What is the capital of Germany?"},
            {"from": "gpt", "value": "The capital of Germany is Berlin."},
        ]
    }
    formatted_item = await env.format_item(item)
    print(formatted_item)
    print(env.tokenizer.decode(formatted_item["tokens"][0]))
    print("--- mask decoding ---")
    print(env.tokenizer.decode([x for x in formatted_item["masks"][0] if x != -100]))
if __name__ == "__main__":
    # Launch the environment via the BaseEnv command-line entry point.
    # (checkout_formatting above is a manual debug helper and is not run here.)
    SFTEnv.cli()