"""Call API providers."""
|
|
|
|
import json
|
|
import os
|
|
import random
|
|
import re
|
|
from typing import Optional
|
|
import time
|
|
|
|
import requests
|
|
|
|
from fastchat.utils import build_logger
|
|
|
|
|
|
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
|
|
|
|
|
def get_api_provider_stream_iter(
    conv,
    model_name,
    model_api_dict,
    temperature,
    top_p,
    max_new_tokens,
    state,
    extra_body=None,
):
    if model_api_dict["api_type"] == "openai":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "openai_no_stream":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            stream=False,
        )
    elif model_api_dict["api_type"] == "openai_o1":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            is_o1=True,
        )
    elif model_api_dict["api_type"] == "openai_assistant":
        last_prompt = conv.messages[-2][1]
        stream_iter = openai_assistant_api_stream_iter(
            state,
            last_prompt,
            assistant_id=model_api_dict["assistant_id"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "anthropic":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "anthropic_message":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_api_dict["model_name"], prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "anthropic_message_vertex":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            vertex_ai=True,
        )
    elif model_api_dict["api_type"] == "gemini":
        prompt = conv.to_gemini_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "gemini_no_stream":
        prompt = conv.to_gemini_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict["api_key"],
            use_stream=False,
        )
    elif model_api_dict["api_type"] == "bard":
        prompt = conv.to_openai_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            None,  # use Bard's default temperature
            None,  # use Bard's default top_p
            max_new_tokens,
            api_key=(model_api_dict["api_key"] or os.environ["BARD_API_KEY"]),
            use_stream=False,
        )
    elif model_api_dict["api_type"] == "mistral":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages(is_mistral=True)
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = mistral_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict.get("api_key"),
        )
    elif model_api_dict["api_type"] == "nvidia":
        prompt = conv.to_openai_api_messages()
        stream_iter = nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
            model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "ai2":
        prompt = conv.to_openai_api_messages()
        stream_iter = ai2_api_stream_iter(
            model_name,
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "vertex":
        prompt = conv.to_vertex_api_messages()
        stream_iter = vertex_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "yandexgpt":
        # note: top_p parameter is unused by yandexgpt

        messages = []
        if conv.system_message:
            messages.append({"role": "system", "text": conv.system_message})
        messages += [
            {"role": role, "text": text}
            for role, text in conv.messages
            if text is not None
        ]

        fixed_temperature = model_api_dict.get("fixed_temperature")
        if fixed_temperature is not None:
            temperature = fixed_temperature

        stream_iter = yandexgpt_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict.get("api_key"),
            folder_id=model_api_dict.get("folder_id"),
        )
    elif model_api_dict["api_type"] == "cohere":
        messages = conv.to_openai_api_messages()
        stream_iter = cohere_api_stream_iter(
            client_name=model_api_dict.get("client_name", "FastChat"),
            model_id=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "reka":
        messages = conv.to_reka_api_messages()
        stream_iter = reka_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "column":
        if model_api_dict.get("vision-arena", False):
            prompt = conv.to_openai_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = column_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=prompt,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "metagen":
        prompt = conv.to_metagen_api_messages()
        stream_iter = metagen_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            conversation_id=state.conv_id,
        )
    elif model_api_dict["api_type"] == "p2l":
        prompt = conv.to_openai_api_messages()
        stream_iter = p2l_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
            extra_body=extra_body,
        )
    else:
        raise NotImplementedError()

    return stream_iter

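# Example: dispatching a conversation to a provider (a minimal sketch; the
# model_api_dict keys mirror those read above, and `conv`/`state` are assumed
# to be fastchat Conversation/State objects):
#
#     model_api_dict = {
#         "api_type": "openai",
#         "model_name": "gpt-4o-mini",
#         "api_base": "https://api.openai.com/v1",
#         "api_key": os.environ["OPENAI_API_KEY"],
#     }
#     stream = get_api_provider_stream_iter(
#         conv, "gpt-4o-mini", model_api_dict, 0.7, 1.0, 1024, state
#     )
#     for chunk in stream:
#         print(chunk["text"])
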
def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
    stream=True,
    is_o1=False,
):
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]

    if "azure" in model_name:
        client = openai.AzureOpenAI(
            api_version="2023-07-01-preview",
            azure_endpoint=api_base or "https://api.openai.com/v1",
            api_key=api_key,
        )
    else:
        client = openai.OpenAI(
            base_url=api_base or "https://api.openai.com/v1",
            api_key=api_key,
            timeout=180,
        )

    # Build a text-only copy of the messages for logging (drop image payloads).
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    if stream and not is_o1:
        res = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
        )
        text = ""
        for chunk in res:
            if len(chunk.choices) > 0:
                text += chunk.choices[0].delta.content or ""
                data = {
                    "text": text,
                    "error_code": 0,
                }
                yield data
    else:
        if is_o1:
            # o1 models require the default temperature (1.0) and reject max_tokens.
            res = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=1.0,
                stream=False,
            )
        else:
            res = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                stream=False,
            )
        text = res.choices[0].message.content
        pos = 0
        while pos < len(text):
            # simulate token streaming
            pos += 2
            time.sleep(0.001)
            data = {
                "text": text[:pos],
                "error_code": 0,
            }
            yield data

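# Example: calling the OpenAI iterator directly (a sketch; assumes
# OPENAI_API_KEY is set in the environment):
#
#     messages = [{"role": "user", "content": "Say hi"}]
#     for chunk in openai_api_stream_iter(
#         "gpt-4o-mini", messages, temperature=0.7, top_p=1.0, max_new_tokens=64
#     ):
#         print(chunk["text"])
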
def column_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    try:
        # Strip attachments before logging the request.
        messages_no_img = []
        for msg in messages:
            msg_no_img = msg.copy()
            msg_no_img.pop("attachment", None)
            messages_no_img.append(msg_no_img)

        gen_params = {
            "model": model_name,
            "messages": messages_no_img,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "seed": 42,
        }
        logger.info(f"==== request ====\n{gen_params}")

        gen_params["messages"] = messages
        gen_params["stream"] = True

        # try 3 times
        for i in range(3):
            try:
                response = requests.post(
                    api_base, json=gen_params, stream=True, timeout=30
                )
                break
            except Exception as e:
                logger.error(f"==== error ====\n{e}")
                if i == 2:
                    yield {
                        "text": "**API REQUEST ERROR** Reason: API timeout. Please try again later.",
                        "error_code": 1,
                    }
                    return

        text = ""
        for line in response.iter_lines():
            if line:
                data = line.decode("utf-8")
                if data.startswith("data:"):
                    data = json.loads(data[6:])["message"]
                    text += data
                    yield {"text": text, "error_code": 0}

    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": "**API REQUEST ERROR** Reason: Unknown.",
            "error_code": 1,
        }

def p2l_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
    extra_body=None,
):
    import openai

    client = openai.OpenAI(
        base_url=api_base,
        api_key=api_key or "-",
        timeout=180,
    )

    # Build a text-only copy of the messages for logging (drop image payloads).
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": None,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "extra_body": extra_body,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_new_tokens,
        stream=True,
        extra_body=extra_body,
    )
    text = ""
    for chunk_idx, chunk in enumerate(res):
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""

        data = {
            "text": text,
            "error_code": 0,
        }

        # The first chunk of a P2L (router) response may carry metadata about
        # which underlying model answered and the router's raw outputs.
        if chunk_idx == 0:
            if hasattr(chunk.choices[0].delta, "model"):
                data["ans_model"] = chunk.choices[0].delta.model

            if hasattr(chunk, "router_outputs"):
                data["router_outputs"] = chunk.router_outputs

        yield data

def upload_openai_file_to_gcs(file_id):
    import openai
    from google.cloud import storage

    storage_client = storage.Client()

    file = openai.files.content(file_id)
    # upload file to GCS
    bucket = storage_client.get_bucket("arena_user_content")
    blob = bucket.blob(f"{file_id}")
    blob.upload_from_string(file.read())
    blob.make_public()
    return blob.public_url

def openai_assistant_api_stream_iter(
    state,
    prompt,
    assistant_id,
    api_key=None,
):
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]
    client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key)

    if state.oai_thread_id is None:
        logger.info("==== create thread ====")
        thread = client.beta.threads.create()
        state.oai_thread_id = thread.id
    logger.info(f"==== thread_id ====\n{state.oai_thread_id}")
    thread_message = client.beta.threads.messages.with_raw_response.create(
        state.oai_thread_id,
        role="user",
        content=prompt,
        timeout=3,
    )
    thread_message = thread_message.parse()
    # Make requests
    gen_params = {
        "assistant_id": assistant_id,
        "thread_id": state.oai_thread_id,
        "message": prompt,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = requests.post(
        f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "assistants=v1",
        },
        json={"assistant_id": assistant_id, "stream": True},
        timeout=30,
        stream=True,
    )

    list_of_text = []
    list_of_raw_text = []
    offset_idx = 0
    full_ret_text = ""
    idx_mapping = {}
    cur_offset = 0
    for line in res.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8")
        if data.endswith("[DONE]"):
            break
        if data.startswith("event"):
            event = data.split(":")[1].strip()
            if event == "thread.message.completed":
                offset_idx += len(list_of_text)
            continue
        data = json.loads(data[6:])

        if data.get("status") == "failed":
            yield {
                "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
                "error_code": 1,
            }
            return

        if data.get("status") == "completed":
            logger.info(f"[debug]: {data}")

        if data["object"] != "thread.message.delta":
            continue

        for delta in data["delta"]["content"]:
            text_index = delta["index"] + offset_idx
            if len(list_of_text) <= text_index:
                list_of_text.append("")
                list_of_raw_text.append("")

            text = list_of_text[text_index]
            raw_text = list_of_raw_text[text_index]

            if delta["type"] == "text":
                # text, url_citation or file_path
                content = delta["text"]
                if "annotations" in content and len(content["annotations"]) > 0:
                    annotations = content["annotations"]

                    raw_text_copy = text
                    for anno in annotations:
                        if anno["type"] == "url_citation":
                            pattern = r"【\d+†source】"
                            matches = re.findall(pattern, content["value"])
                            if len(matches) > 0:
                                for match in matches:
                                    if match not in idx_mapping:
                                        idx_mapping[match] = len(idx_mapping) + 1
                                    citation_number = idx_mapping[match]

                                    start_idx = anno["start_index"] + cur_offset
                                    end_idx = anno["end_index"] + cur_offset
                                    url = anno["url_citation"]["url"]

                                    # Replace the raw citation marker with a
                                    # numbered markdown link.
                                    citation = f" [[{citation_number}]]({url})"
                                    raw_text_copy = (
                                        raw_text_copy[:start_idx]
                                        + citation
                                        + raw_text_copy[end_idx:]
                                    )
                                    cur_offset += len(citation) - (end_idx - start_idx)
                        elif anno["type"] == "file_path":
                            file_public_url = upload_openai_file_to_gcs(
                                anno["file_path"]["file_id"]
                            )
                            raw_text_copy = raw_text_copy.replace(
                                anno["text"], f"{file_public_url}"
                            )
                    text = raw_text_copy
                else:
                    text_content = content["value"]
                    text += text_content
            elif delta["type"] == "image_file":
                image_public_url = upload_openai_file_to_gcs(
                    delta["image_file"]["file_id"]
                )
                text += f"![image]({image_public_url})"

            list_of_text[text_index] = text
            list_of_raw_text[text_index] = raw_text

        full_ret_text = "\n".join(list_of_text)
        yield {"text": full_ret_text, "error_code": 0}

def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    # Uses Anthropic's legacy text-completions API (prompt string, not messages).
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = c.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in res:
        text += chunk.completion
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data

def anthropic_message_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    vertex_ai=False,
):
    import anthropic

    if vertex_ai:
        client = anthropic.AnthropicVertex(
            region=os.environ["GCP_LOCATION"],
            project_id=os.environ["GCP_PROJECT_ID"],
            max_retries=5,
        )
    else:
        client = anthropic.Anthropic(
            api_key=os.environ["ANTHROPIC_API_KEY"],
            max_retries=5,
        )

    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    # Make requests for logging
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # The Messages API takes the system prompt as a separate argument.
    system_prompt = ""
    if messages[0]["role"] == "system":
        if isinstance(messages[0]["content"], dict):
            system_prompt = messages[0]["content"]["text"]
        elif isinstance(messages[0]["content"], str):
            system_prompt = messages[0]["content"]
        # remove system prompt
        messages = messages[1:]

    text = ""
    with client.messages.stream(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens,
        messages=messages,
        model=model_name,
        system=system_prompt,
    ) as stream:
        for chunk in stream.text_stream:
            text += chunk
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

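# Example: streaming from the Anthropic Messages API (a sketch; assumes
# ANTHROPIC_API_KEY is set, or GCP_LOCATION/GCP_PROJECT_ID when vertex_ai=True):
#
#     messages = [
#         {"role": "system", "content": "Answer briefly."},
#         {"role": "user", "content": "What is streaming?"},
#     ]
#     for chunk in anthropic_message_api_stream_iter(
#         "claude-3-5-sonnet-20240620", messages, 0.7, 1.0, 512
#     ):
#         print(chunk["text"])
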
def gemini_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    use_stream=True,
):
    import google.generativeai as genai  # pip install google-generativeai

    if api_key is None:
        api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": messages,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    ]

    # Split history from the final user turn; the system message (if any)
    # becomes the model's system_instruction.
    history = []
    system_prompt = None
    for message in messages[:-1]:
        if message["role"] == "system":
            system_prompt = message["content"]
            continue
        history.append({"role": message["role"], "parts": message["content"]})

    model = genai.GenerativeModel(
        model_name=model_name,
        system_instruction=system_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    convo = model.start_chat(history=history)

    if use_stream:
        response = convo.send_message(messages[-1]["content"], stream=True)
        try:
            text = ""
            chunk = None
            for chunk in response:
                text += chunk.candidates[0].content.parts[0].text
                data = {
                    "text": text,
                    "error_code": 0,
                }
                yield data
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            # Report the last chunk's candidate info when available (it usually
            # carries the block reason); fall back to the exception itself.
            reason = chunk.candidates if chunk is not None else e
            yield {
                "text": f"**API REQUEST ERROR** Reason: {reason}.",
                "error_code": 1,
            }
    else:
        try:
            response = convo.send_message(messages[-1]["content"], stream=False)
            text = response.candidates[0].content.parts[0].text
            pos = 0
            while pos < len(text):
                # simulate token streaming
                pos += 5
                time.sleep(0.001)
                data = {
                    "text": text[:pos],
                    "error_code": 0,
                }
                yield data
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            yield {
                "text": f"**API REQUEST ERROR** Reason: {e}.",
                "error_code": 1,
            }

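# Example: streaming from Gemini (a sketch; assumes GEMINI_API_KEY is set and
# messages use the role/content layout consumed above):
#
#     messages = [
#         {"role": "system", "content": "You are terse."},
#         {"role": "user", "content": "Name three moons of Jupiter."},
#     ]
#     for chunk in gemini_api_stream_iter(
#         "gemini-1.5-pro", messages, temperature=0.7, top_p=1.0, max_new_tokens=256
#     ):
#         print(chunk["text"])
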
def ai2_api_stream_iter(
    model_name,
    model_id,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    # get keys and needed values
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = requests.post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
        timeout=5,
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    text = ""
    for line in res.iter_lines():
        if line:
            part = json.loads(line)
            if "result" in part and "output" in part["result"]:
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

def mistral_api_stream_iter(
    model_name, messages, temperature, top_p, max_new_tokens, api_key=None
):
    from mistralai import Mistral

    if api_key is None:
        api_key = os.environ["MISTRAL_API_KEY"]

    client = Mistral(api_key=api_key)

    # Build a text-only copy of the messages for logging.
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.stream(
        model=model_name,
        temperature=temperature,
        messages=messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    text = ""
    for chunk in res:
        if chunk.data.choices[0].delta.content is not None:
            text += chunk.data.choices[0].delta.content
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data

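# Example: streaming from Mistral (a sketch; assumes MISTRAL_API_KEY is set):
#
#     messages = [{"role": "user", "content": "Bonjour!"}]
#     for chunk in mistral_api_stream_iter(
#         "mistral-large-latest", messages, 0.7, 1.0, 128
#     ):
#         print(chunk["text"])
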
def nvidia_api_stream_iter(
    model_name, messages, temp, top_p, max_tokens, api_base, api_key=None
):
    # Map model names to their NVIDIA API function endpoints.
    model_2_api = {
        "nemotron-4-340b": "/b0fcd392-e905-4ab4-8eb9-aeae95c30b37",
    }
    api_base += model_2_api[model_name]

    api_key = api_key or os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.000001

    payload = {
        "model": model_name,
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    # try 3 times
    for i in range(3):
        try:
            response = requests.post(
                api_base, headers=headers, json=payload, stream=True, timeout=3
            )
            break
        except Exception as e:
            logger.error(f"==== error ====\n{e}")
            if i == 2:
                yield {
                    "text": "**API REQUEST ERROR** Reason: API timeout. Please try again later.",
                    "error_code": 1,
                }
                return

    text = ""
    for line in response.iter_lines():
        if line:
            data = line.decode("utf-8")
            if data.endswith("[DONE]"):
                break
            data = json.loads(data[6:])["choices"][0]["delta"]["content"]
            text += data
            yield {"text": text, "error_code": 0}

def yandexgpt_api_stream_iter(
    model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
):
    api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
    headers = {
        "Authorization": f"Api-Key {api_key}",
        "content-type": "application/json",
    }

    payload = {
        "modelUri": f"gpt://{folder_id}/{model_name}",
        "completionOptions": {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        },
        "messages": messages,
    }
    logger.info(f"==== request ====\n{payload}")

    # api_base, e.g. https://llm.api.cloud.yandex.net/foundationModels/v1/completion
    response = requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=60
    )
    text = ""
    for line in response.iter_lines():
        if line:
            data = json.loads(line.decode("utf-8"))
            data = data["result"]
            top_alternative = data["alternatives"][0]
            text = top_alternative["message"]["text"]
            yield {"text": text, "error_code": 0}

            # Each chunk carries the full text so far; stop once the API marks
            # the alternative final (or truncated-final).
            status = top_alternative["status"]
            if status in (
                "ALTERNATIVE_STATUS_FINAL",
                "ALTERNATIVE_STATUS_TRUNCATED_FINAL",
            ):
                break

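# Example payload shape for YandexGPT (a sketch; note that messages use a
# "text" key rather than the OpenAI-style "content", the model URI embeds the
# Yandex Cloud folder id, and YANDEXGPT_API_KEY is read when api_key is None):
#
#     messages = [
#         {"role": "system", "text": "You are helpful."},
#         {"role": "user", "text": "Hi!"},
#     ]
#     for chunk in yandexgpt_api_stream_iter(
#         "yandexgpt", messages, 0.5, 256,
#         "https://llm.api.cloud.yandex.net/foundationModels/v1/completion",
#         api_key=None, folder_id="<your-folder-id>",
#     ):
#         print(chunk["text"])
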
def cohere_api_stream_iter(
    client_name: str,
    model_id: str,
    messages: list,
    temperature: Optional[float] = None,  # The SDK or API handles None for all parameters following
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var CO_API_KEY
    api_base: Optional[str] = None,
):
    import cohere

    OPENAI_TO_COHERE_ROLE_MAP = {
        "user": "User",
        "assistant": "Chatbot",
        "system": "System",
    }

    client = cohere.Client(
        api_key=api_key,
        base_url=api_base,
        client_name=client_name,
    )

    # prepare and log requests
    chat_history = [
        dict(
            role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]], message=message["content"]
        )
        for message in messages[:-1]
    ]
    actual_prompt = messages[-1]["content"]

    gen_params = {
        "model": model_id,
        "messages": messages,
        "chat_history": chat_history,
        "prompt": actual_prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # make request and stream response
    res = client.chat_stream(
        message=actual_prompt,
        chat_history=chat_history,
        model=model_id,
        temperature=temperature,
        max_tokens=max_new_tokens,
        p=top_p,
    )
    try:
        text = ""
        for streaming_item in res:
            if streaming_item.event_type == "text-generation":
                text += streaming_item.text
                yield {"text": text, "error_code": 0}
    except cohere.core.ApiError as e:
        logger.error(f"==== error from cohere api: {e} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}",
            "error_code": 1,
        }

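# Example: streaming from Cohere (a sketch; assumes CO_API_KEY is set, so
# api_key/api_base may be left as None):
#
#     messages = [{"role": "user", "content": "Write a haiku about tea."}]
#     for chunk in cohere_api_stream_iter(
#         client_name="FastChat", model_id="command-r-plus", messages=messages
#     ):
#         print(chunk["text"])
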
def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )

    project_id = os.environ.get("GCP_PROJECT_ID", None)
    location = os.environ.get("GCP_LOCATION", None)
    vertexai.init(project=project_id, location=location)

    # Keep only plain-string messages for logging.
    text_messages = []
    for message in messages:
        if isinstance(message, str):
            text_messages.append(message)

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    safety_settings = [
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
        generative_models.SafetySetting(
            category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        ),
    ]
    generator = GenerativeModel(model_name).generate_content(
        messages,
        stream=True,
        generation_config=GenerationConfig(
            top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
        ),
        safety_settings=safety_settings,
    )

    ret = ""
    for chunk in generator:
        # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129
        ret += chunk.candidates[0].content.parts[0]._raw_part.text
        # ret += chunk.text
        data = {
            "text": ret,
            "error_code": 0,
        }
        yield data

def reka_api_stream_iter(
    model_name: str,
    messages: list,
    temperature: Optional[float] = None,  # The SDK or API handles None for all parameters following
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var REKA_API_KEY
    api_base: Optional[str] = None,
):
    from reka.client import Reka
    from reka import TypedText

    api_key = api_key or os.environ["REKA_API_KEY"]

    client = Reka(api_key=api_key)

    # A "-online" suffix selects the search-engine-augmented variant.
    use_search_engine = False
    if "-online" in model_name:
        model_name = model_name.replace("-online", "")
        use_search_engine = True
    request = {
        "model_name": model_name,
        "conversation_history": messages,
        "temperature": temperature,
        "request_output_len": max_new_tokens,
        "runtime_top_p": top_p,
        "stream": True,
        "use_search_engine": use_search_engine,
    }

    # Make requests for logging
    text_messages = []
    for turn in messages:
        for message in turn.content:
            if isinstance(message, TypedText):
                text_messages.append({"type": message.type, "text": message.text})
    logged_request = dict(request)
    logged_request["conversation_history"] = text_messages

    logger.info(f"==== request ====\n{logged_request}")

    response = client.chat.create_stream(
        messages=messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
        model=model_name,
    )

    for chunk in response:
        try:
            yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
        except Exception:
            yield {
                "text": "**API REQUEST ERROR** ",
                "error_code": 1,
            }

def metagen_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key,
    api_base,
    conversation_id,
):
    try:
        # Build a text-only copy of the messages for logging.
        text_messages = []
        for message in messages:
            if isinstance(message["content"], str):  # text-only model
                text_messages.append(message)
            else:  # vision model
                filtered_content_list = [
                    content
                    for content in message["content"]
                    if content["type"] == "text"
                ]
                text_messages.append(
                    {"role": message["role"], "content": filtered_content_list}
                )
        gen_params = {
            "model": model_name,
            "prompt": text_messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
        }
        logger.info(f"==== request ====\n{gen_params}")

        res = requests.post(
            f"{api_base}/chat_stream_completions?access_token={api_key}",
            stream=True,
            headers={"Content-Type": "application/json"},
            json={
                "model": model_name,
                "chunks_delimited": True,
                "messages": messages,
                "conversation_id": conversation_id,
                "options": {
                    "max_tokens": max_new_tokens,
                    "generation_algorithm": "top_p",
                    "top_p": top_p,
                    "temperature": temperature,
                },
            },
            timeout=30,
        )

        if res.status_code != 200:
            logger.error(f"Unexpected response ({res.status_code}): {res.text}")
            yield {
                "text": "**API REQUEST ERROR** Reason: Unknown.",
                "error_code": 1,
            }
            return

        text = ""
        for line in res.iter_lines():
            if line:
                part = json.loads(line.decode("utf-8"))
                if "text" in part:
                    text += part["text"]
                    data = {
                        "text": text,
                        "error_code": 0,
                    }
                    yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": "**API REQUEST ERROR** Reason: Unknown.",
            "error_code": 1,
        }