# Mirror of https://github.com/lilakk/BLEUBERI.git
# (synced 2026-04-19 12:58:12 +00:00; original file: 13 lines, 355 B, Python)
"""Memory-efficient training entry point.

Monkey patches the LLaMA attention implementation with FlashAttention to make
training more memory efficient, then runs FastChat's standard ``train()``
entry point.
"""

# The patch must be applied before transformers is imported. Importing the
# patch helper itself is done first, and ``fastchat.train.train`` is imported
# only afterwards -- presumably because that module pulls in transformers at
# import time (TODO confirm).
from fastchat.train.llama2_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

# Deliberately executed at module level so the patch is in place as soon as
# this file is imported or run, i.e. before ``train`` is imported below.
replace_llama_attn_with_flash_attn()

from fastchat.train.train import train  # noqa: E402  (must follow the patch)

if __name__ == "__main__":
    train()