mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
251 lines
8 KiB
Python
import base64
|
|
import json
|
|
import random
|
|
import re
|
|
import traceback
|
|
from typing import List, Optional, Tuple
|
|
|
|
from datasets import load_dataset
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.type_definitions import GameHistory, Item
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
|
|
class MultimodalExampleEnv(BaseEnv):
    """Example multimodal RL environment over the CLEVR-CoGenT dataset.

    Each item pairs a text question with a rendered scene image; rollouts
    are scored by comparing the model's ``\\boxed{...}`` answer against the
    dataset's ``<answer>...</answer>`` gold solution.
    """

    name = "clevr_cogen_a_train"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _encode_image_base64(image) -> str:
        """Encode a PIL image as a base64 PNG string for data-URL embedding.

        Shared by ``get_next_item`` and ``collect_trajectories`` (the
        original duplicated this snippet in both places).
        """
        import io

        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @staticmethod
    def _extract_text_content(user_content) -> str:
        """Pull the text portion out of a possibly JSON-encoded user message.

        Falls back to a generic boxed-answer instruction when no text
        element is found or the JSON cannot be parsed.
        """
        fallback = "Please solve this problem and provide your answer as \\boxed{answer}."
        if isinstance(user_content, str) and (
            user_content.startswith("[") or user_content.startswith("{")
        ):
            try:
                parsed = json.loads(user_content)
                text_content = ""
                for element in parsed:
                    if element.get("type") == "text":
                        text_content = element.get("text", "")
                return text_content or fallback
            except Exception:
                return fallback
        return user_content

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        """Roll out ``group_size`` completions for *item* and return them for scoring.

        :param item: ``(prompt, gold_answer, base64_image)`` tuple from
            ``get_next_item``.
        :return: Tuple of (list of ``(messages, answer, base64_image)``
            entries to score, backlog items).
        """
        to_score = list()
        to_backlog = list()

        # Start from the image already attached to the item.  The original
        # code bound base64_image only inside the hasattr branch below, so
        # the else path raised NameError at to_score.append(...).
        base64_image = item[2] if len(item) > 2 else None

        if hasattr(self, "current_image"):
            # Re-encode the stashed PIL image as a base64 PNG.
            base64_image = self._encode_image_base64(self.current_image)

            text_content = self._extract_text_content(
                dict(item[0][0]).get("content", "")
            )

            # Multimodal chat format: text element plus an inline data-URL image.
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_content},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                            },
                        },
                    ],
                },
            ]
        else:
            # Text-only fallback when no image was stashed on the env.
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}",
                },
                dict(item[0][0]),
            ]

        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=1024 * 2,
            timeout=60,  # Add timeout to prevent hanging (60 seconds is more reasonable)
        )

        for chat_completion in chat_completions.choices:
            # Tokenization in score() uses the text-only prompt; the image
            # rides along separately as the third tuple element.
            trajectory_messages = (
                dict(item[0][0]),
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append((trajectory_messages, item[1], base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        """No post-processing needed; all scoring happens in ``score``."""
        pass

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the environment, this is called every steps_per_eval steps

        This example environment does not run a separate eval pass.

        :param args:
        :param kwargs:
        :return: None.
        """
        return

    async def setup(self):
        """Setup the environment and load the multimodal dataset"""
        self.dataset = load_dataset("leonardPKU/clevr_cogen_a_train")
        self.train = self.dataset["train"]
        self.iter = 0  # cursor into the training split; wraps via modulo

    async def get_next_item(self) -> Item:
        """
        Get the next item to be rolled out, including the image.

        The PIL image is also stashed on ``self.current_image`` so
        ``collect_trajectories`` can embed it into the chat request; the
        prompt itself stays text-only because lists are unhashable inside
        frozensets.  On any failure a trivial arithmetic fallback item is
        returned so the rollout loop keeps moving.
        """
        try:
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1

            # Store image as a class attribute so collect_trajectories can access it
            self.current_image = next_item["image"]

            prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": next_item["problem"]}.items()
                    )
                ]
            )
            answer = next_item["solution"]
            base64_image = self._encode_image_base64(self.current_image)

            return (prompt, answer, base64_image)

        except Exception:
            traceback.print_exc()

            # Create a dummy item as fallback
            prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            return (prompt, "4", "obobob")

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Score a rollout group: +1.0 for a correct boxed answer, -1.0 otherwise.

        The gold answer is a 1-2 digit count wrapped in
        ``<answer>...</answer>`` inside each item's solution string.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["images"] = list()
        random.shuffle(rollout_group_data)
        # Hoisted out of the loop; compiled once per group.
        gold_pattern = re.compile(r"<answer>\s*(\d{1,2})\s*</answer>")
        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            try:
                # Model answer: contents of the last \boxed{...} in the reply.
                model_answer = (
                    item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                )

                # Gold answer comes from THIS item's solution.  The original
                # always read rollout_group_data[0][1], which breaks if items
                # in a group carry different answers, and an unmatched
                # pattern raised an uncaught AttributeError on .group(1).
                match = gold_pattern.search(item[1])
                reward = match is not None and match.group(1) == model_answer
            except (IndexError, AttributeError, KeyError, TypeError):
                reward = False

            # remove obviously bad examples (too few trainable tokens)
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)

            # The image may be absent for text-only fallback items.
            scores["images"].append(item[2] if len(item) > 2 else None)
            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference-server configs."""
        config = BaseEnvConfig(
            wandb_name="clevr_cogen_a_train",
            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
            group_size=8,
            use_wandb=True,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
        )

        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2-VL-2B-Instruct",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return config, server_configs
|
|
|
|
|
|
# Script entry point: launch the environment via the BaseEnv command-line interface.
if __name__ == "__main__":
    MultimodalExampleEnv.cli()
|