mirror of https://github.com/lilakk/BLEUBERI.git (synced 2026-04-19 12:58:12 +00:00)
77 lines · 2.2 KiB · Python
import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
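    """Stream generation with an ExLlamaV2 model, yielding partial results.

    Each yielded dict carries the text generated so far, token usage counts,
    and a finish reason (set only on the final yield). The `device`,
    `context_len`, `stream_interval`, and `judge_sent_end` arguments are
    accepted for interface compatibility but are not used in this body.
    """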
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]
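    # Wrap the caller's model (expected to expose `.model` and `.cache`) and the
    # tokenizer in a streaming generator, then configure sampling from `params`,
    # falling back to the inline defaults.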
    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])
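    # Length cap, explicit stop tokens, and whether to echo the prompt in the output.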
    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))
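    # Tokenize the prompt and start the streaming session.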
    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)
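    # Stream chunk by chunk, yielding the accumulated text and usage counts after
    # each step until the length cap or an end-of-stream signal is hit.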
    generated_tokens = 0
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            # Natural end of generation (EOS or a stop condition), not a length cutoff.
            finish_reason = "stop"
            break
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }
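    # Final yield reports the finish reason determined in the loop above.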
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
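
# Minimal usage sketch (illustrative only; the object passed as `model` is
# assumed to expose `.model` and `.cache` attributes, as dereferenced above):
#
#     params = {"prompt": "Hello, world", "temperature": 0.7, "max_new_tokens": 64}
#     for out in generate_stream_exllama(
#         model, tokenizer, params, device="cuda", context_len=4096
#     ):
#         text = out["text"]
#     print(text)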