mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
import sys
from dataclasses import dataclass, field
from typing import Optional
|
|
|
|
|
|
@dataclass
class ExllamaConfig:
    """Settings for loading a model through the ExLlamaV2 backend."""

    # Maximum context length (in tokens) the model/cache should support.
    max_seq_len: int
    # Comma-separated per-GPU split values (parsed as floats by the loader),
    # or None to let ExLlamaV2 decide the placement itself.
    gpu_split: Optional[str] = None
    # If True, use the 8-bit KV-cache variant instead of the default cache.
    cache_8bit: bool = False
|
|
|
|
|
|
class ExllamaModel:
    """Thin wrapper pairing an ExLlamaV2 model with its KV cache."""

    def __init__(self, exllama_model, exllama_cache):
        """Keep the model/cache pair together and surface the model's config.

        Args:
            exllama_model: the loaded ExLlamaV2 model instance.
            exllama_cache: the cache object built for that model.
        """
        self.model = exllama_model
        self.cache = exllama_cache
        # Expose the underlying model's config directly for convenience.
        self.config = exllama_model.config
|
|
|
|
|
|
def load_exllama_model(model_path, exllama_config: "ExllamaConfig"):
    """Load an ExLlamaV2 model and its tokenizer from a local model directory.

    Args:
        model_path: Path to the directory holding the model files.
        exllama_config: Backend settings (context length, optional GPU split,
            8-bit cache toggle).

    Returns:
        A ``(ExllamaModel, ExLlamaV2Tokenizer)`` tuple.

    Exits:
        Terminates the process with status -1 if ``exllamav2`` cannot be
        imported.
    """
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_8bit,
        )
    except ImportError as e:
        # Diagnostics belong on stderr so they survive stdout redirection.
        print(f"Error: Failed to load Exllamav2. {e}", file=sys.stderr)
        sys.exit(-1)

    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    # Override values derived from the model dir with our requested settings.
    exllamav2_config.max_seq_len = exllama_config.max_seq_len
    exllamav2_config.cache_8bit = exllama_config.cache_8bit

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    # Optional manual GPU split: "a,b,..." -> list of floats; None = automatic.
    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    # Pick the cache implementation matching the 8-bit toggle.
    cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache
    exllama_cache = cache_class(exllama_model)
    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

    return model, tokenizer
|