BLEUBERI/eval/FastChat/fastchat/model/model_chatglm.py
2025-06-04 20:36:43 +00:00

137 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Inference code for ChatGLM.
Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py.
"""
import re
import torch
from transformers.generation.logits_process import LogitsProcessor
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that neutralizes score tensors containing NaN or Inf.

    If any entry of ``scores`` is NaN or infinite, the whole tensor is
    overwritten in place; otherwise it is passed through unchanged.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return ``scores``, reset in place when it holds any NaN/Inf value.

        ``input_ids`` is accepted for interface compatibility and not read.
        """
        invalid_mask = torch.isnan(scores) | torch.isinf(scores)
        if not invalid_mask.any():
            return scores

        # Wipe every score, then put a large value at index 5 of the last
        # dimension so that entry dominates.
        # NOTE(review): presumably index 5 is a safe fallback token id in the
        # ChatGLM vocabulary -- confirm against the tokenizer.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores


# Shared module-level instance, passed to generation as a logits processor.
invalid_score_processor = InvalidScoreLogitsProcessor()
def process_response(response):
    """Clean up a decoded ChatGLM response string.

    Steps, in order:
      1. Strip leading/trailing whitespace.
      2. Replace the ``[[训练时间]]`` placeholder with ``2023年``.
      3. Convert the ASCII punctuation marks ``, ! : ; ?`` to their
         full-width forms when they sit directly before or after a CJK
         ideograph in the range U+4E00-U+9FFF. Punctuation that is not
         adjacent to such a character is left untouched.

    Args:
        response: The decoded model output.

    Returns:
        The cleaned response string.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (regex for the ASCII mark, full-width replacement). The replacements are
    # written as escapes so they cannot be silently dropped or confused with
    # their ASCII look-alikes by editors and web viewers.
    punkts = (
        (",", "\uff0c"),  # full-width comma
        ("!", "\uff01"),  # full-width exclamation mark
        (":", "\uff1a"),  # full-width colon
        (";", "\uff1b"),  # full-width semicolon
        (r"\?", "\uff1f"),  # full-width question mark; "?" must be escaped in a regex
    )
    for ascii_pattern, fullwidth in punkts:
        # Mark immediately after a CJK character.
        response = re.sub(
            r"([\u4e00-\u9fff])%s" % ascii_pattern, r"\1%s" % fullwidth, response
        )
        # Mark immediately before a CJK character.
        response = re.sub(
            r"%s([\u4e00-\u9fff])" % ascii_pattern, r"%s\1" % fullwidth, response
        )
    return response
def recover_message_list(prompt):
    """Split a flattened chat prompt back into a list of role/content dicts.

    The prompt is scanned for the role tags ``<|system|>``, ``<|user|>`` and
    ``<|assistant|>``. Each tag opens a message whose content runs from one
    character past the tag (skipping the single separator character that
    follows it) up to the start of the next tag.

    The text after the final tag is never emitted: a message is only closed
    when a following tag is found. A prompt with no tags yields ``[]``.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in prompt order.
    """
    role_by_tag = {"<|system|>": "system", "<|user|>": "user"}
    tag_regex = "|".join(
        re.escape(tag) for tag in ("<|system|>", "<|user|>", "<|assistant|>")
    )

    messages = []
    open_tag = None
    content_start = 0
    for hit in re.finditer(tag_regex, prompt):
        if open_tag:
            # The tag seen previously is closed by the one just found.
            messages.append(
                {
                    "role": role_by_tag.get(open_tag, "assistant"),
                    "content": prompt[content_start : hit.start()],
                }
            )
        open_tag = hit.group()
        # +1 skips the one character that separates a tag from its content.
        content_start = hit.end() + 1
    return messages
@torch.inference_mode()
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream a ChatGLM generation as a sequence of progressively longer chunks.

    Args:
        model: Model exposing ``stream_generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        params: Request dict. ``"prompt"`` is required; ``"temperature"``,
            ``"repetition_penalty"``, ``"top_p"`` (default 1.0 each),
            ``"max_new_tokens"`` (default 256) and ``"echo"`` (default True)
            are optional.
        device, context_len, stream_interval, judge_sent_end: Accepted as part
            of the signature but not read by this function.

    Yields:
        Dicts with ``"text"``, ``"usage"`` (prompt/completion/total token
        counts) and ``"finish_reason"``. Every streamed chunk has
        ``finish_reason`` ``None``; one extra final chunk repeats the last
        text with ``finish_reason`` ``"stop"``.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    # When the type name mentions "peft", inspect the type of the object at
    # model.base_model.model instead.
    model_type = str(type(model)).lower()
    if "peft" in model_type:
        model_type = str(type(model.base_model.model)).lower()

    if "chatglm3" not in model_type:
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    else:
        # ChatGLM3 takes structured chat input: rebuild the message list from
        # the flat prompt and pass the last message as the query.
        history = recover_message_list(prompt)
        latest = history[-1]
        inputs = tokenizer.build_chat_input(
            query=latest["content"], history=history[:-1], role="user"
        ).to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    sampling = temperature > 1e-5
    gen_kwargs = dict(
        max_length=max_new_tokens + input_echo_len,
        do_sample=sampling,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        logits_processor=[invalid_score_processor],
    )
    # Temperature is only forwarded when sampling is enabled.
    if sampling:
        gen_kwargs["temperature"] = temperature

    def usage_for(length):
        # Build a fresh usage dict for a sequence of `length` total tokens.
        return {
            "prompt_tokens": input_echo_len,
            "completion_tokens": length - input_echo_len,
            "total_tokens": length,
        }

    total_len = 0
    for step_ids in model.stream_generate(**inputs, **gen_kwargs):
        token_ids = step_ids.tolist()[0]
        total_len = len(token_ids)
        # With echo the prompt tokens are decoded too; otherwise drop them.
        visible_ids = token_ids if echo else token_ids[input_echo_len:]
        response = process_response(tokenizer.decode(visible_ids))
        yield {
            "text": response,
            "usage": usage_for(total_len),
            "finish_reason": None,
        }

    # TODO: ChatGLM stops when it reaches max length.
    # Only the last stream result carries a finish_reason; it is always "stop".
    # NOTE(review): `response` is only bound inside the loop above, so this
    # assumes stream_generate yields at least once -- confirm.
    yield {
        "text": response,
        "usage": usage_for(total_len),
        "finish_reason": "stop",
    }