mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
266 lines
10 KiB
Python
266 lines
10 KiB
Python
import asyncio
|
|
from enum import Enum
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from datasets import load_dataset
|
|
from pydantic import Field
|
|
|
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
|
from atroposlib.type_definitions import Item
|
|
|
|
|
|
class DatasetFormat(Enum):
|
|
COMPLETION = "completion"
|
|
PREFIXED_COMPLETION = "prefix_completion"
|
|
OAI_FORMAT = "oai_format"
|
|
SHAREGPT_FORMAT = "sharegpt_format"
|
|
|
|
|
|
class SFTConfig(BaseEnvConfig):
|
|
dataset_name: str = Field(
|
|
default="gsm8k",
|
|
description="The name of the HF dataset to use for training",
|
|
)
|
|
dataset_format: DatasetFormat = Field(
|
|
default=DatasetFormat.COMPLETION,
|
|
description="The format of the dataset to use for training",
|
|
)
|
|
dataset_column_name: str = Field(
|
|
default="text",
|
|
description="The name of the column to use for training",
|
|
)
|
|
prefix_column_name: Optional[str] = Field(
|
|
default=None,
|
|
description="The name of the column to use for the prefix, if applicable",
|
|
)
|
|
mask_everything_but_assistant_answer: bool = Field(
|
|
default=True,
|
|
description="Whether to mask everything but the assistant answer, if applicable",
|
|
)
|
|
mask_everything_but_last_step: bool = Field(
|
|
default=False,
|
|
description="Whether to mask everything but the last step, if applicable",
|
|
)
|
|
max_sft_per_step: int = Field(
|
|
default=-1,
|
|
description="The maximum number of SFTs to do per step, if -1 just sends all the data",
|
|
)
|
|
add_every_n_steps: int = Field(
|
|
default=1,
|
|
description="Only add SFT data every n steps",
|
|
)
|
|
|
|
|
|
class SFTEnv(BaseEnv):
|
|
|
|
name = "sft_loader"
|
|
|
|
def __init__(
|
|
self,
|
|
config: SFTConfig,
|
|
server_configs: List[OpenaiConfig],
|
|
slurm=True,
|
|
testing=False,
|
|
):
|
|
super().__init__(config, server_configs, slurm, testing)
|
|
self.percent_correct_buffer = list()
|
|
self.eval_metrics = list()
|
|
# Add tracking for wandb visualizations
|
|
self.rollouts_for_wandb = []
|
|
self.completion_lengths = []
|
|
self.idx = 0
|
|
self.last_step = -1
|
|
|
|
@classmethod
|
|
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
|
env_config = SFTConfig(
|
|
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
|
group_size=8,
|
|
use_wandb=True,
|
|
rollout_server_url="http://localhost:8000",
|
|
total_steps=1000,
|
|
batch_size=12,
|
|
steps_per_eval=100,
|
|
max_token_length=2048,
|
|
ensure_scores_are_not_same=False,
|
|
wandb_name="sft",
|
|
dataset_name="AlignmentLab-AI/open-instruct-sharegpt",
|
|
dataset_format=DatasetFormat.SHAREGPT_FORMAT,
|
|
dataset_column_name="conversations",
|
|
inference_weight=-1,
|
|
max_sft_per_step=8,
|
|
)
|
|
server_configs = [
|
|
OpenaiConfig(
|
|
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
|
base_url="http://localhost:9001/v1",
|
|
api_key="x",
|
|
num_requests_for_eval=256,
|
|
),
|
|
]
|
|
|
|
return env_config, server_configs
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
if wandb_metrics is None:
|
|
wandb_metrics = {}
|
|
wandb_metrics["train/idx"] = self.idx
|
|
wandb_metrics["train/epoch"] = self.idx // len(self.train)
|
|
# Call the parent method to handle the server metrics
|
|
await super().wandb_log(wandb_metrics)
|
|
|
|
async def setup(self):
|
|
self.train = load_dataset(self.config.dataset_name, split="train").shuffle(
|
|
seed=42
|
|
)
|
|
self.iter = 0
|
|
|
|
async def get_next_item(self) -> Item:
|
|
next_item = self.train[self.idx % len(self.train)]
|
|
# self.idx += 1
|
|
return next_item
|
|
|
|
async def evaluate(self, *args, **kwargs):
|
|
pass
|
|
|
|
async def add_train_workers(self):
|
|
pass
|
|
|
|
async def format_item(self, item: Item):
|
|
group = ScoredDataGroup()
|
|
group["tokens"] = list()
|
|
group["masks"] = list()
|
|
group["messages"] = list()
|
|
group["scores"] = [1]
|
|
group["advantages"] = [1]
|
|
# GRPO/RLOO should natively work if they don't apply mean/std to a 1 item group and have 0 KL term,
|
|
# but in case you are using something different, e.g. PPO,
|
|
# we'll send an override to the server to tell it to do a SFT step.
|
|
group["group_overrides"] = {"sft": True}
|
|
group["overrides"] = [{"sft": True}]
|
|
if self.config.dataset_format == DatasetFormat.COMPLETION:
|
|
tokens = self.tokenizer.encode(item[self.config.dataset_column_name])
|
|
group["tokens"].append(tokens)
|
|
group["messages"].append(item[self.config.dataset_column_name])
|
|
group["masks"].append(tokens)
|
|
elif self.config.dataset_format == DatasetFormat.PREFIXED_COMPLETION:
|
|
tokens = self.tokenizer.encode(
|
|
item[self.config.prefix_column_name]
|
|
+ item[self.config.dataset_column_name]
|
|
)
|
|
prefix_tokens = self.tokenizer.encode(item[self.config.prefix_column_name])
|
|
if prefix_tokens[-1] == self.tokenizer.eos_token_id:
|
|
prefix_tokens = prefix_tokens[:-1]
|
|
group["tokens"].append(tokens)
|
|
group["messages"].append(
|
|
item[self.config.prefix_column_name]
|
|
+ item[self.config.dataset_column_name]
|
|
)
|
|
group["masks"].append(
|
|
[-100 for _ in range(len(prefix_tokens))] + tokens[len(prefix_tokens) :]
|
|
)
|
|
else:
|
|
if self.config.dataset_format == DatasetFormat.OAI_FORMAT:
|
|
messages = item[self.config.dataset_column_name]
|
|
elif self.config.dataset_format == DatasetFormat.SHAREGPT_FORMAT:
|
|
messages = item[self.config.dataset_column_name]
|
|
# reformat to oai format
|
|
role_map = {"human": "user", "gpt": "assistant", "system": "system"}
|
|
messages = [
|
|
{"role": role_map[message["from"]], "content": message["value"]}
|
|
for message in messages
|
|
]
|
|
group["messages"].append(messages)
|
|
tokens = self.tokenizer.apply_chat_template(messages)
|
|
group["tokens"].append(tokens)
|
|
if self.config.mask_everything_but_last_step:
|
|
masks = [
|
|
-100
|
|
for _ in self.tokenizer.apply_chat_template(
|
|
messages[:-1], add_generation_prompt=True
|
|
)
|
|
]
|
|
masks.extend(tokens[len(masks) :])
|
|
group["masks"].append(masks)
|
|
elif self.config.mask_everything_but_assistant_answer:
|
|
# find the assistant answer in the chats
|
|
masks = []
|
|
for i, msg in enumerate(messages):
|
|
if msg["role"] == "assistant":
|
|
prefix_tokens = self.tokenizer.apply_chat_template(
|
|
messages[:i], add_generation_prompt=True
|
|
)
|
|
masks.extend(
|
|
[-100 for _ in range(len(prefix_tokens[len(masks) :]))]
|
|
)
|
|
curr_assist_tokens = self.tokenizer.apply_chat_template(
|
|
messages[: i + 1]
|
|
)
|
|
masks.extend(curr_assist_tokens[len(masks) :])
|
|
group["masks"].append(masks)
|
|
return group
|
|
|
|
async def env_step_checks(self):
|
|
# Check if we need to run an eval or log...
|
|
if self.curr_step != self.status_dict["current_step"]:
|
|
next_step = self.status_dict["current_step"]
|
|
if next_step % self.config.add_every_n_steps == 0:
|
|
to_send = (
|
|
self.config.max_batches_offpolicy * self.config.batch_size
|
|
) - (self.status_dict["queue_size"])
|
|
if self.config.max_sft_per_step != -1:
|
|
to_send = min(to_send, self.config.max_sft_per_step)
|
|
self.items_sent_this_step = to_send
|
|
if to_send > 0:
|
|
formatted_items = list()
|
|
for _ in range(to_send):
|
|
item = await self.get_next_item()
|
|
self.idx += 1
|
|
formatted_items.append(await self.format_item(item))
|
|
await asyncio.gather(
|
|
*[
|
|
self.handle_send_to_api(formatted_item, None)
|
|
for formatted_item in formatted_items
|
|
]
|
|
)
|
|
await super().env_step_checks()
|
|
|
|
|
|
async def checkout_formatting():
|
|
# SFTEnv.cli()
|
|
env = SFTEnv(
|
|
config=SFTConfig(
|
|
dataset_name="AlignmentLab-AI/open-instruct-sharegpt",
|
|
dataset_format=DatasetFormat.SHAREGPT_FORMAT,
|
|
dataset_column_name="conversations",
|
|
),
|
|
server_configs=[
|
|
OpenaiConfig(
|
|
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
|
base_url="http://localhost:9001/v1",
|
|
api_key="x",
|
|
num_requests_for_eval=256,
|
|
)
|
|
],
|
|
slurm=False,
|
|
testing=True,
|
|
)
|
|
await env.setup()
|
|
item = await env.get_next_item()
|
|
item = {
|
|
"conversations": [
|
|
{"from": "human", "value": "What is the capital of France?"},
|
|
{"from": "gpt", "value": "The capital of France is Paris."},
|
|
{"from": "human", "value": "What is the capital of Germany?"},
|
|
{"from": "gpt", "value": "The capital of Germany is Berlin."},
|
|
]
|
|
}
|
|
formatted_item = await env.format_item(item)
|
|
print(formatted_item)
|
|
print(env.tokenizer.decode(formatted_item["tokens"][0]))
|
|
print("--- mask decoding ---")
|
|
print(env.tokenizer.decode([x for x in formatted_item["masks"][0] if x != -100]))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
SFTEnv.cli()
|