mirror of
https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""
|
|
Chat with a model through a command-line interface.

Usage:
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0

Other commands:
- Type "!!exit" or an empty line to exit.
- Type "!!reset" to start a new conversation.
- Type "!!remove" to remove the last prompt.
- Type "!!regen" to regenerate the last message.
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
|
|
"""
|
|
import argparse
|
|
import os
|
|
import re
|
|
import sys
|
|
|
|
from prompt_toolkit import PromptSession
|
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
|
from prompt_toolkit.completion import WordCompleter
|
|
from prompt_toolkit.history import InMemoryHistory
|
|
from prompt_toolkit.key_binding import KeyBindings
|
|
from rich.console import Console
|
|
from rich.live import Live
|
|
from rich.markdown import Markdown
|
|
import torch
|
|
|
|
from fastchat.model.model_adapter import add_model_args
|
|
from fastchat.modules.awq import AWQConfig
|
|
from fastchat.modules.exllama import ExllamaConfig
|
|
from fastchat.modules.xfastertransformer import XftConfig
|
|
from fastchat.modules.gptq import GptqConfig
|
|
from fastchat.serve.inference import ChatIO, chat_loop
|
|
from fastchat.utils import str_to_torch_dtype
|
|
|
|
|
|
class SimpleChatIO(ChatIO):
    """Plain-text chat I/O built on the built-in ``input()`` and ``print()``."""

    def __init__(self, multiline: bool = False):
        """Store whether prompts should collect several lines of input.

        Args:
            multiline: When true, ``prompt_for_input`` keeps reading lines
                until end-of-file instead of returning after the first line.
        """
        self._multiline = multiline

    def prompt_for_input(self, role) -> str:
        """Read one user message from standard input.

        In single-line mode the first line typed is returned as-is.  In
        multiline mode lines are collected (each stripped of surrounding
        whitespace) until ``input()`` raises ``EOFError``, then joined with
        newlines.

        Raises:
            EOFError: If end-of-file is hit on the very first line; only the
                follow-up reads in multiline mode are guarded.
        """
        if not self._multiline:
            return input(f"{role}: ")

        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError:
                # End-of-file terminates the multiline message.
                break
        return "\n".join(prompt_data)

    def prompt_for_output(self, role: str):
        """Print the role prefix without a newline so the reply follows it."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print a streamed reply word by word and return the final text.

        Each item of ``output_stream`` is indexed with ``"text"``; the value is
        split on single spaces and only the words not yet printed are emitted.
        The last word of every snapshot is held back (it is printed only once
        a later snapshot extends past it, or after the stream ends).
        NOTE(review): this presumes each item carries the full text generated
        so far rather than a delta -- confirm against the producer.

        Returns:
            The stripped text of the last item, or ``""`` when the stream
            yields nothing.
        """
        pre = 0
        # Start with an empty word list so an empty stream prints a bare
        # newline and returns "" instead of raising UnboundLocalError.
        output_text = []
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        # Flush whatever has not been printed yet and end the line.
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print a complete message in one go."""
        print(text)
|
|
|
|
|
|
class RichChatIO(ChatIO):
    """Chat I/O that reads input with prompt_toolkit and renders replies with rich."""

    # Key bindings shared by all instances; only passed to the prompt when
    # multiline mode is enabled (see ``prompt_for_input``).
    bindings = KeyBindings()

    @bindings.add("escape", "enter")
    def _(event):
        # ESC followed by Enter inserts a newline into the current buffer.
        event.app.current_buffer.newline()

    def __init__(self, multiline: bool = False, mouse: bool = False):
        """Create the prompt session, command completer and rich console.

        Args:
            multiline: Enables the ESC+Enter newline key binding.
            mouse: Forwarded to the prompt as ``mouse_support``.
        """
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        """Print the role header, prompt for one message and return it."""
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        """Print the role header for a reply, with "/" shown as "|"."""
        self._console.print(f"[bold]{role.replace('/', '|')}:")

    @staticmethod
    def _with_hard_breaks(text: str) -> str:
        """Return ``text`` with every line ending turned into a Markdown hard break.

        Chatbot output uses a single "\\n" as a real line break, whereas
        standard Markdown renders it as a space.  Appending two spaces to each
        line forces a hard break.  Lines starting with a code-fence marker get
        a plain newline instead, because trailing spaces there would break the
        syntax highlighting.  Lines inside a code block still receive the
        trailing spaces, which is harmless for console output.
        """
        pieces = []
        for line in text.splitlines():
            pieces.append(line)
            pieces.append("\n" if line.startswith("```") else "  \n")
        return "".join(pieces)

    def stream_output(self, output_stream):
        """Stream output from a role, re-rendering it live as Markdown.

        Falsy items are skipped; every other item is indexed with ``"text"``
        and that text replaces what is currently displayed.

        Returns:
            The text of the last non-empty item, or ``""`` when the stream
            yields no usable item.
        """
        # TODO(suquark): the console flickers when there is a code block
        # above it. We need to cut off "live" when a code block is done.

        # Default so an empty stream returns "" instead of raising
        # UnboundLocalError at the final ``return``.
        text = ""
        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # Render the latest text as Markdown and update the display.
                live.update(Markdown(self._with_hard_breaks(text)))
        self._console.print()
        return text

    def print_output(self, text: str):
        """Render a complete message by streaming it as a single item."""
        self.stream_output([{"text": text}])
|
|
|
|
|
|
class ProgrammaticChatIO(ChatIO):
    """Chat I/O for driving the CLI from another program over stdin/stdout."""

    def prompt_for_input(self, role) -> str:
        """Read one message from stdin, terminated by the end-of-message marker.

        Input is consumed until a line ends with the marker; the marker is
        removed and the message is echoed back as ``[!OP:<role>]: <message>``.

        Returns:
            The message text without the trailing marker.

        Raises:
            EOFError: If stdin is closed before the marker is received.  At
                end of input ``sys.stdin`` only ever yields ``""``, so waiting
                for more data would loop forever.
        """
        # `end_sequence` signals the end of a message. It is unlikely to occur in
        # message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"

        received = []
        while True:
            # The marker ends with "\n" and contains no other newline, so it
            # can only be completed at the end of a line; reading line by line
            # is equivalent to matching character by character.
            line = sys.stdin.readline()
            if not line:
                raise EOFError(
                    "stdin was closed before the end-of-message marker was received"
                )
            received.append(line)
            if line.endswith(end_sequence):
                break

        contents = "".join(received)[: -len(end_sequence)]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        """Print the tagged role prefix without a newline."""
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print a streamed reply word by word and return the final text.

        Each item of ``output_stream`` is indexed with ``"text"``; the value is
        split on single spaces and only the words not yet printed are emitted.
        The last word of every snapshot is held back until a later snapshot
        extends past it or the stream ends.
        NOTE(review): this presumes each item carries the full text generated
        so far rather than a delta -- confirm against the producer.

        Returns:
            The stripped text of the last item, or ``""`` when the stream
            yields nothing.
        """
        pre = 0
        # Start with an empty word list so an empty stream prints a bare
        # newline and returns "" instead of raising UnboundLocalError.
        output_text = []
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        # Flush whatever has not been printed yet and end the line.
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print a complete message in one go."""
        print(text)
|
|
|
|
|
|
def main(args):
    """Run an interactive chat session configured by parsed CLI arguments.

    Restricts the visible devices when ``--gpus`` is given, builds the
    optional exllama / xFasterTransformer configs, picks the chat I/O
    implementation for ``args.style`` and hands everything to ``chat_loop``.
    A ``KeyboardInterrupt`` raised while the loop is running ends the session
    with an "exit..." message.

    Raises:
        ValueError: If ``--num-gpus`` exceeds the number of ids in ``--gpus``,
            or if ``args.style`` is not a known console style.
    """
    if args.gpus:
        gpu_ids = args.gpus.split(",")
        if args.num_gpus > len(gpu_ids):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        # Expose the same device list to both the CUDA and XPU runtimes.
        for env_name in ("CUDA_VISIBLE_DEVICES", "XPU_VISIBLE_DEVICES"):
            os.environ[env_name] = args.gpus

    exllama_config = None
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )

    xft_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer now is only support CPUs. Reset device to CPU")
            args.device = "cpu"

    # Factories are lazy so only the selected I/O class is ever instantiated.
    chatio_factories = {
        "simple": lambda: SimpleChatIO(args.multiline),
        "rich": lambda: RichChatIO(args.multiline, args.mouse),
        "programmatic": lambda: ProgrammaticChatIO(),
    }
    make_chatio = chatio_factories.get(args.style)
    if make_chatio is None:
        raise ValueError(f"Invalid style for console: {args.style}")
    chatio = make_chatio()

    try:
        # Built inside the ``try`` and in call-argument order, so an interrupt
        # while preparing them is handled the same way as one in the loop.
        torch_dtype = str_to_torch_dtype(args.dtype)
        gptq_config = GptqConfig(
            ckpt=args.gptq_ckpt or args.model_path,
            wbits=args.gptq_wbits,
            groupsize=args.gptq_groupsize,
            act_order=args.gptq_act_order,
        )
        awq_config = AWQConfig(
            ckpt=args.awq_ckpt or args.model_path,
            wbits=args.awq_wbits,
            groupsize=args.awq_groupsize,
        )
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            torch_dtype,
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            gptq_config=gptq_config,
            awq_config=awq_config,
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: model-loading flags are registered by
    # add_model_args, followed by the chat-specific options below.
    parser = argparse.ArgumentParser()
    add_model_args(parser)

    # Conversation setup.
    parser.add_argument(
        "--conv-template",
        type=str,
        default=None,
        help="Conversation prompt template.",
    )
    parser.add_argument(
        "--conv-system-msg",
        type=str,
        default=None,
        help="Conversation system message.",
    )

    # Generation settings.
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--no-history", action="store_true")

    # Console presentation.
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )

    # Boolean switches that share the same shape, registered in order.
    for switch, description in (
        ("--multiline", "Enable multiline input. Use ESC+Enter for newline."),
        ("--mouse", "[Rich Style]: Enable mouse support for cursor positioning."),
        (
            "--judge-sent-end",
            "Whether enable the correction logic that interrupts the output of sentences due to EOS.",
        ),
        ("--debug", "Print useful debug information (e.g., prompts)"),
    ):
        parser.add_argument(switch, action="store_true", help=description)

    args = parser.parse_args()
    main(args)
|