"""
|
||
Inference code for ChatGLM.
|
||
Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py.
|
||
"""
|
||
import re
|
||
|
||
import torch
|
||
from transformers.generation.logits_process import LogitsProcessor
|
||
|
||
|
||
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If the logits contain NaN or inf, zero everything out and put a
        # large score on a single fixed token id so decoding cannot crash.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


invalid_score_processor = InvalidScoreLogitsProcessor()
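

# A minimal demo (not in the original file) of the processor's behavior:
# NaN/inf logits are replaced by a one-hot-like score on token id 5, so
# decoding always has a valid token to pick. Shapes are illustrative.
def _demo_invalid_score_processor():
    scores = torch.full((1, 8), float("nan"))
    fixed = invalid_score_processor(input_ids=None, scores=scores)
    assert fixed[0, 5] == 5e4
    assert fixed.sum() == 5e4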


def process_response(response):
    response = response.strip()
    # Replace the "[[训练时间]]" ("training time") placeholder the model may
    # emit with a concrete year.
    response = response.replace("[[训练时间]]", "2023年")
    # Half-width -> full-width punctuation pairs; the left entry is a regex
    # fragment, so "?" must stay escaped (and lives in a raw string).
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        # Convert punctuation that directly follows or precedes a CJK character.
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response
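

# A quick illustration (not in the original file): punctuation touching a
# CJK character is promoted to its full-width form; ASCII-only text is left
# unchanged.
def _demo_process_response():
    assert process_response("你好!") == "你好!"
    assert process_response("Hello!") == "Hello!"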


def recover_message_list(prompt):
    role_token_pattern = "|".join(
        [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]]
    )
    role = None
    last_end_idx = -1
    message_list = []
    for match in re.finditer(role_token_pattern, prompt):
        if role:
            message = {}
            if role == "<|system|>":
                message["role"] = "system"
            elif role == "<|user|>":
                message["role"] = "user"
            else:
                message["role"] = "assistant"
            # Skip the single separator character (e.g. "\n") after the role token.
            message["content"] = prompt[last_end_idx + 1 : match.start()]
            message_list.append(message)

        role = prompt[match.start() : match.end()]
        last_end_idx = match.end()

    return message_list
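

# A small example (not in the original file) of how a flat templated prompt
# is split back into a message list. Note that the character immediately
# after each role token (here "\n") is dropped by the parser, and a trailing
# role token with no content acts as the generation cue.
def _demo_recover_message_list():
    prompt = "<|user|>\nHello<|assistant|>\nHi there!<|assistant|>"
    assert recover_message_list(prompt) == [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]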


@torch.inference_mode()
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    # Unwrap PEFT adapters so the check below sees the base model class.
    model_type = str(type(model)).lower()
    if "peft" in model_type:
        model_type = str(type(model.base_model.model)).lower()

    if "chatglm3" in model_type:
        # ChatGLM3 expects a structured chat history rather than a flat prompt.
        message_list = recover_message_list(prompt)
        inputs = tokenizer.build_chat_input(
            query=message_list[-1]["content"], history=message_list[:-1], role="user"
        ).to(model.device)
    else:
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": temperature > 1e-5,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [invalid_score_processor],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids
        else:
            output_ids = total_ids[input_echo_len:]
        response = tokenizer.decode(output_ids)
        response = process_response(response)

        yield {
            "text": response,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
            "finish_reason": None,
        }

    # TODO: ChatGLM also stops when it reaches max length; that case is not
    # yet distinguished from a natural stop. Only the last stream result
    # carries a finish_reason, which we set to "stop".
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret
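

# A hedged usage sketch (not part of the original file): stream a reply from
# a locally loaded ChatGLM checkpoint. The checkpoint name, prompt template,
# and sampling parameters are illustrative assumptions, not values fixed by
# this module; ChatGLM checkpoints expose stream_generate via their remote
# modeling code (trust_remote_code=True).
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm3-6b"  # assumed checkpoint, for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()

    params = {
        "prompt": "<|user|>\nIntroduce yourself in one sentence.<|assistant|>",
        "temperature": 0.8,
        "top_p": 0.9,
        "max_new_tokens": 128,
        "echo": False,
    }
    chunk = None
    for chunk in generate_stream_chatglm(model, tokenizer, params, device="cuda"):
        pass  # each chunk holds the full decoded text so far
    if chunk is not None:
        print(chunk["text"])
        print(chunk["usage"])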