mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
108 lines
3.1 KiB
Python
108 lines
3.1 KiB
Python
import gc
|
|
from threading import Thread
|
|
import torch
|
|
import transformers
|
|
from transformers import (
|
|
GenerationConfig,
|
|
StoppingCriteria,
|
|
StoppingCriteriaList,
|
|
TextIteratorStreamer,
|
|
)
|
|
|
|
|
|
@torch.inference_mode()
def generate_stream_codet5p(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream the text generated by ``model`` for ``params["prompt"]``.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and
    yields the accumulated text at regular intervals.

    Args:
        model: Object exposing ``generate(**kwargs)``; it is called with the
            tokenizer encoding plus ``decoder_input_ids``, ``streamer``,
            ``generation_config`` and ``stopping_criteria``.
        tokenizer: Callable tokenizer providing ``eos_token_id``; also used
            by the streamer to decode generated ids.
        params: Request dict. ``"prompt"`` is required. Optional keys and
            their defaults: ``temperature`` (1.0), ``repetition_penalty``
            (1.0), ``top_p`` (1.0), ``top_k`` (50), ``max_new_tokens``
            (1024), ``stop_token_ids`` (none). The dict and its values are
            not modified.
        device: Device the encoded prompt is moved to; also compared against
            ``"xpu"`` / ``"npu"`` to pick the cache-clearing call at the end.
        context_len: Accepted for interface compatibility; not used here.
        stream_interval: Yield a partial result every this many chunks
            received from the streamer.
        judge_sent_end: Accepted for interface compatibility; not used here.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``. ``finish_reason`` is
        ``None`` for partial results; the final item carries ``"length"``
        when the chunk counter reached ``max_new_tokens`` and ``"stop"``
        otherwise.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    # NOTE(review): the value is forwarded unchanged to GenerationConfig;
    # confirm the installed transformers version really treats -1 as
    # "disabled".
    top_k = int(params.get("top_k", 50))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 1024))
    # Copy before appending EOS so the caller's list inside ``params`` is
    # never mutated (otherwise it grows on every call that reuses params).
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    stop_token_ids.append(tokenizer.eos_token_id)

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, **decode_config)
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # The prompt ids are also passed as the decoder's starting sequence.
    encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    # Prompt length is the size of the last (sequence) dimension;
    # len(input_ids) would return the size of the leading batch dimension.
    input_echo_len = input_ids.shape[-1]

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature >= 1e-5,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
        eos_token_id=stop_token_ids,
    )

    class CodeBlockStopper(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            # Code-completion is open-end generation.
            # We check \n\n to stop at end of a code block.
            # Only the first sequence in the batch is inspected.
            return input_ids[0][-2:].tolist() == [628, 198]

    gen_kwargs = dict(
        **encoding,
        streamer=streamer,
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    )
    # Generation runs in a worker thread; the streamer hands decoded text
    # back to this generator as it becomes available.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # ``i`` counts chunks received from the streamer and is reported as the
    # completion token count.
    i = 0
    output = ""
    for new_text in streamer:
        i += 1
        output += new_text
        if i % stream_interval == 0 or i == max_new_tokens - 1:
            yield {
                "text": output,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
                "finish_reason": None,
            }
        if i >= max_new_tokens:
            break

    if i >= max_new_tokens:
        finish_reason = "length"
    else:
        finish_reason = "stop"

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }
    # Wait for the generation thread before releasing memory.
    thread.join()

    # clean
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
|