# Mirror of https://github.com/lilakk/BLEUBERI.git
# (synced 2026-04-19 12:58:12 +00:00; 226 lines, 7.5 KiB, Python)
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from collections import defaultdict
|
|
import copy
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
import random
|
|
import json
|
|
import logging
|
|
import pathlib
|
|
from typing import Dict, Optional, Sequence, List
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
|
|
from deepspeed import zero
|
|
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
|
|
|
|
import transformers
|
|
from torch.utils.data import Dataset
|
|
from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed
|
|
|
|
from fastchat.train.train_flant5 import (
|
|
smart_tokenizer_and_embedding_resize,
|
|
make_supervised_data_module,
|
|
)
|
|
|
|
from fastchat.train.train_lora import get_peft_state_maybe_zero_3
|
|
|
|
from fastchat.model.model_adapter import get_conversation_template
|
|
|
|
# Conversation template looked up under the name "t5" when this module is imported.
# NOTE(review): not referenced again in the visible part of this file -- confirm
# whether other modules rely on it before removing.
default_conversation = get_conversation_template("t5")
|
|
|
|
# TODO: import and use code from ../data/dataset.py
|
|
|
|
# Module-level constants.
# NOTE(review): only DEFAULT_PAD_TOKEN is used in the visible part of this
# file (as the pad token handed to smart_tokenizer_and_embedding_resize in
# train()); the others are presumably kept for importers -- confirm.

# Presumably the label id that the loss skips -- TODO confirm against the data module.
IGNORE_INDEX: int = -100

# Pad token registered on the tokenizer by train().
DEFAULT_PAD_TOKEN: str = "[PAD]"

# EOS, BOS and UNK all share the same "</s>" literal.
DEFAULT_EOS_TOKEN: str = "</s>"
DEFAULT_BOS_TOKEN: str = "</s>"
DEFAULT_UNK_TOKEN: str = "</s>"
|
|
|
|
|
|
@dataclass
class LoraArguments:
    """Command-line options describing the LoRA adapter.

    ``train()`` passes ``lora_r``, ``lora_alpha``, ``lora_target_modules``,
    ``lora_dropout`` and ``lora_bias`` to ``LoraConfig``; ``q_lora`` switches
    on the 4-bit quantized loading path.
    """

    lora_r: int = field(default=8)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    # A factory is required so every instance gets its own list.
    lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
    # NOTE(review): not read anywhere in the visible part of this file.
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")
    q_lora: bool = field(default=False)
|
|
|
|
|
|
@dataclass
class ModelArguments:
    """Command-line options selecting the pretrained checkpoint.

    ``model_name_or_path`` is passed to both the model and the tokenizer
    ``from_pretrained`` calls in ``train()``.
    """

    model_name_or_path: Optional[str] = "facebook/opt-125m"
|
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line options describing the training data.

    The whole object is handed to ``make_supervised_data_module`` in
    ``train()``; how each field is interpreted is decided there.
    """

    # Defaults to None, so the annotation is Optional[str] (matching how the
    # file's other None-default fields are declared).
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    # NOTE(review): -1 presumably means "use every example" -- confirm in the
    # data module.
    num_data: int = -1
    preprocessed_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the preprocessed training data."}
    )
|
|
|
|
|
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training arguments plus the options this script adds.

    ``cache_dir`` and ``model_max_length`` are forwarded to the
    ``from_pretrained`` calls in ``train()``; ``optim`` is re-declared with
    the default "adamw_torch".
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
        default=2048,
    )
|
|
|
|
|
|
def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str, state_dict: dict
):
    """Copy a state dict to CPU and write it out through the trainer.

    Does nothing unless ``trainer.args.should_save`` is true.

    Args:
        trainer: Trainer whose ``_save`` method performs the write.
        output_dir: Directory passed to ``trainer._save``.
        state_dict: Mapping of names to values exposing ``.cpu()``.
    """
    if not trainer.args.should_save:
        return

    host_weights = {}
    for name, tensor in state_dict.items():
        host_weights[name] = tensor.cpu()
    # Drop this function's reference to the incoming mapping before writing.
    del state_dict
    trainer._save(output_dir, state_dict=host_weights)  # noqa
|
|
|
|
|
|
def train():
    """Fine-tune a seq2seq model with LoRA (optionally QLoRA) and save it.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataArguments``, ``TrainingArguments``
         and ``LoraArguments`` from the command line.
      2. Load the base model (4-bit quantized when ``--q_lora`` is set) and
         wrap it with a LoRA adapter.
      3. Load the T5 tokenizer, register the pad token and extra tokens, and
         build the supervised data module.
      4. Train, resuming automatically if ``output_dir`` already contains a
         ``checkpoint-*`` entry.
      5. Collect the state dict and save it from the local main process.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        # Under DDP, pin the whole quantized model to this process's device.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )

    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type=TaskType.SEQ_2_SEQ_LM,
    )

    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
            model.is_parallelizable = True
            model.model_parallel = True

    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend
    # a space before special tokens.
    tokenizer = transformers.T5Tokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # Resume automatically when a previous run left checkpoints behind.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        # use deepspeed engine internal function to gather state dict
        # state_dict_zero3 contains whole parameters of base and lora adapters
        # we will not extract lora parameters since peft save_pretrained will do that
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
        # Called on every rank; the result is only consumed by the saving
        # process below, so binding it unconditionally is harmless and keeps
        # ``state_dict`` defined on all ranks.
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        # in other mode we use original code from fastchat team, to make sure our change is minimum
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )

    # ``local_rank`` is -1 when the script runs as a plain single process on
    # transformers versions that do not rewrite it; treating only 0 as the
    # main process would skip saving entirely in that case.
    if training_args.local_rank in (-1, 0):
        safe_save_model_for_hf_trainer(
            trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
        )
|
|
|
|
|
|
# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    train()
|