# atropos/environments/multimodal_dpo/clevr_complex.py
import base64
import io
import json
import random
import traceback
from typing import List, Optional, Tuple

from datasets import load_dataset

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class MultimodalComplexEnv(BaseEnv):
    """Trajectory-collection environment for the CLEVR CoGenT "complex" split.

    Each dataset row pairs a rendered scene image with a text question.
    The policy is prompted with the question plus the image (attached as a
    base64 PNG data URL), and each sampled completion is scored by comparing
    the answer inside ``\\boxed{...}`` against the dataset's gold solution.
    """

    name = "clevr_complex"
    name_config_cls = BaseEnvConfig

    # Fallback instruction used when an item's content cannot be parsed.
    _FALLBACK_PROMPT = (
        "Please solve this problem and provide your answer as \\boxed{answer}."
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _image_to_base64(image) -> str:
        """Encode a PIL image as a base64 PNG string (no data-URL prefix).

        Factored out: this encoding was previously duplicated verbatim in
        ``collect_trajectories`` and ``get_next_item``.
        """
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @classmethod
    def _extract_text_content(cls, user_content) -> str:
        """Pull the plain-text question out of *user_content*.

        The content may be a plain string, or a JSON-encoded list of
        ``{"type": "text", "text": ...}`` elements; anything empty or
        unparsable falls back to a generic instruction.
        """
        if not (
            isinstance(user_content, str)
            and user_content.startswith(("[", "{"))
        ):
            return user_content
        try:
            parsed = json.loads(user_content)
            text_content = ""
            for element in parsed:
                if element.get("type") == "text":
                    text_content = element.get("text", "")
            return text_content or cls._FALLBACK_PROMPT
        except Exception:
            return cls._FALLBACK_PROMPT

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        """Sample ``group_size`` completions for *item* and score them.

        :param item: ``(prompt, gold_answer, base64_image)`` tuple produced
            by :meth:`get_next_item`; the prompt is a tuple of frozensets.
        :return: ``(scored_group, backlog)`` — the backlog is always empty.
        """
        to_score = list()
        to_backlog = list()
        # Defined on every path: the original code raised NameError in the
        # no-image branch because base64_image was only bound under the if.
        base64_image = None
        if hasattr(self, "current_image"):
            base64_image = self._image_to_base64(self.current_image)
            user_content = dict(item[0][0]).get("content", "")
            text_content = self._extract_text_content(user_content)
            # Multimodal chat format: text part plus an inline PNG data URL.
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_content},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                            },
                        },
                    ],
                },
            ]
        else:
            # Text-only fallback when get_next_item stashed no image.
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}",
                },
                dict(item[0][0]),
            ]
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=1024 * 2,
            timeout=60,  # prevent hung requests from stalling the rollout loop
        )
        for chat_completion in chat_completions.choices:
            trajectory = (
                dict(item[0][0]),
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append((trajectory, item[1], base64_image))
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def evaluate(self, *args, **kwargs):
        """Evaluate the environment; called every ``steps_per_eval`` steps.

        Evaluation is not implemented for this environment.
        """
        return

    async def setup(self):
        """Load the CLEVR CoGenT complex training split and reset the cursor."""
        self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        """Return the next ``(prompt, answer, base64_image)`` rollout item.

        The PIL image is also stashed on ``self.current_image`` because the
        hashable prompt tuple (frozensets) cannot carry a list/image payload.
        Falls back to a dummy arithmetic item on any dataset error.
        """
        try:
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1
            # Stash the image so collect_trajectories can attach it later.
            self.current_image = next_item["image"]
            text_prompt = next_item["problem"]
            # Frozenset-of-items keeps the prompt hashable (avoids the
            # unhashable-list-in-frozenset error).
            prompt = tuple(
                [frozenset({"role": "user", "content": text_prompt}.items())]
            )
            answer = next_item["solution"]
            base64_image = self._image_to_base64(self.current_image)
            return (prompt, answer, base64_image)
        except Exception:
            traceback.print_exc()
            # Dummy fallback item so training can continue past bad rows.
            prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            answer = "4"
            return (prompt, answer, "obobob")

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Score a group of rollouts by exact ``\\boxed{}`` answer match.

        :param rollout_group_data: list of ``(messages, gold_answer, image)``
            tuples produced by :meth:`collect_trajectories`.
        :return: ScoredDataGroup with +1.0 for correct answers, -1.0 otherwise.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["images"] = list()
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]
            try:
                # Take the text between the last "\boxed{" and the next "}".
                model_answer = (
                    item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                )
                # Fix: use this item's own gold answer; the original read
                # rollout_group_data[0][1], which is only correct because
                # every member of the group happens to share one item.
                gold_answer = item[1]
                if gold_answer.lower() in ["yes", "no"]:
                    # Case-insensitive comparison for yes/no answers.
                    reward = gold_answer.lower() == model_answer.lower()
                else:
                    # Exact string comparison for numeric answers.
                    reward = gold_answer == model_answer
            except (IndexError, AttributeError):
                # AttributeError: the completion content may be None.
                reward = False
            # Drop degenerate samples with almost no trainable tokens.
            if len([1 for i in masks if i != -100]) < 10:
                continue
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(item[2])
            except IndexError:
                scores["images"].append(None)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Default environment config and inference-server endpoints."""
        config = BaseEnvConfig(
            wandb_name="clevr_complex",
            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
            group_size=8,
            use_wandb=True,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2-VL-2B-Instruct",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return config, server_configs
# Entry point: launch the environment through the BaseEnv CLI runner.
if __name__ == "__main__":
    MultimodalComplexEnv.cli()