import os
from types import SimpleNamespace
import warnings

import torch
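
# The rwkv package reads these environment variables at import time, so they
# must be set before `from rwkv.model import RWKV` below: RWKV_JIT_ON enables
# the TorchScript path and RWKV_CUDA_ON builds the custom CUDA kernel.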
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "1"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
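

# A thin wrapper exposing just enough of the Hugging Face model interface
# (`config`, `__call__`, `generate`) for FastChat to drive an RWKV checkpoint.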
class RwkvModel:
    def __init__(self, model_path):
        warnings.warn(
            "Experimental support. Please use ChatRWKV if you want to chat with RWKV"
        )
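        # FastChat only inspects `config.is_encoder_decoder`, so a bare
        # SimpleNamespace stands in for a full transformers config object.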
        self.config = SimpleNamespace(is_encoder_decoder=False)
        self.model = RWKV(model=model_path, strategy="cuda fp16")
        # To shard the model across two GPUs instead, use a strategy such as:
        # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16")

        # The tokenizer is loaded lazily on the first call to generate().
        self.tokenizer = None
        self.model_path = model_path
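
    # Placement on the GPU already happened via the strategy string above, so
    # to() only sanity-checks that callers request the expected device.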
    def to(self, target):
        assert target == "cuda"
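
    # A single forward step shaped like a Hugging Face causal LM call:
    # `past_key_values` carries the RWKV recurrent state between steps.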
    def __call__(self, input_ids, use_cache, past_key_values=None):
        assert use_cache
        input_ids = input_ids[0].detach().cpu().numpy()
        logits, state = self.model.forward(input_ids, past_key_values)
        # Restore the (batch, sequence, vocab) shape that callers expect.
        logits = logits.unsqueeze(0).unsqueeze(0)
        out = SimpleNamespace(logits=logits, past_key_values=state)
        return out

    def generate(
        self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
    ):
        # This function is used by fastchat.llm_judge. Because RWKV does not
        # support the Hugging Face generation API, we reuse
        # fastchat.serve.inference.generate_stream as a workaround.
        from transformers import AutoTokenizer

        from fastchat.serve.inference import generate_stream
        from fastchat.conversation import get_conv_template

        if self.tokenizer is None:
            # RWKV "Pile" checkpoints use the GPT-NeoX tokenizer, which the
            # Pythia models share, so a Pythia tokenizer works here.
            self.tokenizer = AutoTokenizer.from_pretrained(
                "EleutherAI/pythia-160m", use_fast=True
            )
        prompt = self.tokenizer.decode(input_ids[0].tolist())
        conv = get_conv_template("rwkv")

        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }
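
        # generate_stream treats this object as the model (it invokes the
        # __call__ defined above) and yields progressively longer outputs.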
        res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda")

        # Drain the stream; only the final, complete result is kept.
        for res in res_iter:
            pass
        output = res["text"]
        output_ids = self.tokenizer.encode(output)

        # Mimic transformers' generate(): return prompt ids plus new token ids.
        return [input_ids[0].tolist() + output_ids]
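

# A minimal usage sketch, not part of the original module: the checkpoint path
# below is hypothetical, and in practice this wrapper is driven by
# fastchat.llm_judge rather than invoked directly.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = RwkvModel("/path/to/RWKV-4-Pile-3B.pth")  # hypothetical path
    model.to("cuda")

    tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", use_fast=True)
    input_ids = torch.tensor([tok.encode("The capital of France is")])
    output_ids = model.generate(
        input_ids, do_sample=True, temperature=0.7, max_new_tokens=32
    )
    print(tok.decode(output_ids[0]))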