atropos/atroposlib/cli/view_run_multimodal.py
2025-05-07 21:53:01 +00:00

184 lines
6.1 KiB
Python

import argparse
import asyncio
import base64
import json
import re
from io import BytesIO

import aiohttp
import gradio as gr
import PIL.Image
from transformers import AutoTokenizer
def find_common_prefix(strings):
    """Return the longest leading substring shared by all of ``strings``.

    Args:
        strings: Sequence of strings to compare.

    Returns:
        The shared prefix, or ``""`` when ``strings`` is empty or the
        entries have no leading characters in common.
    """
    if not strings:
        return ""

    shared, *others = strings
    for candidate in others:
        # Drop trailing characters until ``candidate`` begins with what is left.
        while shared and not candidate.startswith(shared):
            shared = shared[:-1]
        if not shared:
            return ""
    return shared
async def register_to_api(group_size, max_token_len):
    """Reset the local API's data and then register with it.

    Sends a GET to ``/reset_data`` followed by a POST to ``/register`` on
    ``http://localhost:8000`` and prints both response bodies.

    Args:
        group_size: Group size; the registered ``batch_size`` is eight times
            this value.
        max_token_len: Forwarded as ``max_token_len`` in the registration.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get("http://localhost:8000/reset_data") as reset_reply:
            reset_body = await reset_reply.text()
            print(reset_body)

        print(group_size)

        registration = {
            "wandb_group": "test",
            "wandb_project": "test",
            # * 8 just in case you want to just sample from a large group
            "batch_size": group_size * 8,
            "max_token_len": max_token_len,
            "checkpoint_dir": "checkpoints",
            "save_checkpoint_interval": 10,
            "starting_step": 0,
            "num_steps": 69,
        }
        register_call = session.post(
            "http://localhost:8000/register", json=registration
        )
        async with register_call as register_reply:
            print("output of register is")
            register_body = await register_reply.text()
            print(register_body)
async def check_for_batch():
    """Poll the local API until a batch is available, then return it.

    Issues GET requests to ``http://localhost:8000/batch``, printing each
    JSON reply, and waits one second between attempts.

    Returns:
        The value stored under the ``"batch"`` key of the first reply in
        which that value is not ``None``.
    """
    # A single session serves the whole polling loop so connections are
    # reused instead of a new session being built for every attempt.
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.get("http://localhost:8000/batch") as response:
                data = await response.json()
            print(data)
            if data["batch"] is not None:
                return data["batch"]
            # The response is already released here, so nothing is held open
            # while waiting for the next attempt.
            await asyncio.sleep(1)
def extract_image_from_chat(chat_text):
    """Decode the first inline base64 JPEG/PNG data URI found in ``chat_text``.

    Args:
        chat_text: Text that may contain a ``data:image/...;base64,`` URI.

    Returns:
        The opened ``PIL.Image.Image``, or ``None`` when no data URI is
        present or decoding/opening fails (the error is printed).
    """
    # The payload runs up to the next double quote or backslash.
    found = re.search(r'data:image/(jpeg|png);base64,([^"\\]*)', chat_text)
    if found is None:
        return None

    encoded = found.group(2)
    try:
        raw_bytes = base64.b64decode(encoded)
        return PIL.Image.open(BytesIO(raw_bytes))
    except Exception as e:
        # Best effort: a bad payload must not break the viewer.
        print(f"Error decoding image: {e}")
    return None
def extract_text_from_chat(chat_text):
    """Extract the text prompt from a decoded multimodal chat string.

    Three forms are tried in order:

    1. JSON multimodal content containing ``"type": "text", "text": "..."``;
       the string value is returned with JSON escapes decoded.
    2. Text containing an ``[Image]`` marker; everything after the first
       marker is returned, stripped of surrounding whitespace.
    3. Anything else is returned unchanged.

    Args:
        chat_text: The chat text to inspect.

    Returns:
        The extracted prompt text.
    """
    if '"type": "text"' in chat_text:
        # Accept backslash-escaped characters inside the JSON string so the
        # capture does not end at the first escaped quote (\").
        text_pattern = r'"type": "text", "text": "((?:[^"\\]|\\.)*)"'
        match = re.search(text_pattern, chat_text, re.DOTALL)
        if match:
            raw = match.group(1)
            try:
                # Decode escapes such as \" and \n; strict=False tolerates
                # literal control characters inside the captured text.
                return json.loads(f'"{raw}"', strict=False)
            except ValueError:
                # Not valid JSON string content: return it as captured.
                return raw

    # If not in JSON format, look for the [Image] prefix.
    if "[Image]" in chat_text:
        return chat_text.split("[Image]", 1)[1].strip()

    # Return the original text if no pattern is found.
    return chat_text
def _image_from_batch_item(item, fallback_text):
    """Return the image for one batch group, or ``None`` if none can be read.

    The first entry of ``item["images"]`` is used when it is present and
    truthy: a PIL image is returned as is, and a string is base64-decoded
    (after dropping any ``data:image...,`` header). Otherwise
    ``fallback_text`` is searched for an inline data URI.
    """
    try:
        if "images" in item and item["images"] and item["images"][0]:
            print("Found image data in batch")
            candidate = item["images"][0]
            # Already a PIL image: use it directly.
            if isinstance(candidate, PIL.Image.Image):
                return candidate
            if isinstance(candidate, str):
                # Keep only the base64 part of a data URI.
                if candidate.startswith("data:image"):
                    candidate = candidate.split(",", 1)[1]
                return PIL.Image.open(BytesIO(base64.b64decode(candidate)))
            print(f"Image type not recognized: {type(candidate)}")
            return None
        # Try to extract the image from the chat text as a fallback.
        print("No images field found, trying to extract from chat text")
        return extract_image_from_chat(fallback_text)
    except Exception as e:
        # Best effort: a broken image must not prevent showing the text.
        print(f"Error processing image: {e}")
        return None


async def build_interface(group_size, max_token_len, tokenizer, port):
    """Build and launch the Gradio viewer for multimodal batches.

    The page shows an image, the text prompt, and ``group_size`` pairs of
    output/score text boxes. Pressing "Grab Next Batch" waits for a batch
    from the API, decodes the first group in it and fills the page.

    Args:
        group_size: Number of outputs and scores displayed per batch.
        max_token_len: Passed to ``register_to_api``.
        tokenizer: Name or path given to ``AutoTokenizer.from_pretrained``.
        port: Port passed to ``demo.launch`` as ``server_port``.
    """
    # Load the tokenizer once instead of on every button click.
    tok = AutoTokenizer.from_pretrained(tokenizer)

    async def grab_batch():
        """Fetch the next batch and return values for every output widget."""
        data = await check_for_batch()
        print(data)
        group = data[0]
        chats = [tok.decode(chat) for chat in group["tokens"]]
        # The text shared by every chat is treated as the prompt.
        prefix = find_common_prefix(chats)
        image = _image_from_batch_item(group, prefix)
        text_prompt = extract_text_from_chat(prefix)

        # Slice the prefix off rather than using str.split(prefix), which
        # raises ValueError for an empty prefix and drops text whenever the
        # prefix occurs again later in the chat.
        completions = [chat[len(prefix):] for chat in chats[:group_size]]
        scores = list(group["scores"][:group_size])
        # Pad short groups so the number of returned values always matches
        # the number of output components wired to the button.
        completions += [""] * (group_size - len(completions))
        scores += [""] * (group_size - len(scores))

        return (
            image,  # Image
            text_prompt,  # Text prompt
            *completions,  # Model outputs
            *scores,  # Scores
        )

    with gr.Blocks() as demo:
        image_blk = gr.Image(label="Image", type="pil")
        prompt_blk = gr.Textbox(label="Text Prompt")
        with gr.Row():
            score_blks = [
                gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)
            ]
        with gr.Row():
            outputs_blks = [
                gr.Textbox(label=f"Output_{i+1}") for i in range(group_size)
            ]
        with gr.Row():
            grab_next = gr.Button(value="Grab Next Batch")
        # The order here must match the tuple returned by grab_batch:
        # image, prompt, outputs, then scores.
        grab_next.click(
            fn=grab_batch,
            outputs=[image_blk, prompt_blk] + outputs_blks + score_blks,
            api_name="get_batch",
        )

    await register_to_api(group_size, max_token_len)
    # NOTE(review): share=True also requests a public Gradio share link;
    # confirm that is intended for this debugging tool.
    demo.launch(server_port=port, share=True)
def main():
    """Parse the command-line options and run the multimodal batch viewer."""
    cli = argparse.ArgumentParser()
    # (flag, type, default) for every supported option.
    option_table = (
        ("--port", int, 9001),
        ("--group-size", int, 2),
        ("--max-token-len", int, 2048),
        ("--tokenizer", str, "Qwen/Qwen2-VL-2B-Instruct"),
    )
    for flag, kind, fallback in option_table:
        cli.add_argument(flag, type=kind, default=fallback)
    opts = cli.parse_args()

    viewer = build_interface(
        opts.group_size, opts.max_token_len, opts.tokenizer, opts.port
    )
    asyncio.run(viewer)


if __name__ == "__main__":
    main()