mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
283
environments/multimodal_dpo/clevr_complex.py
Normal file
283
environments/multimodal_dpo/clevr_complex.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class MultimodalComplexEnv(BaseEnv):
|
||||
name = "clevr_complex"
|
||||
name_config_cls = BaseEnvConfig
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[GameHistory | None, List[Item]]:
|
||||
print("DEBUG: Starting collect_trajectories")
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
# Get the current image if it was stored
|
||||
if hasattr(self, "current_image"):
|
||||
print("DEBUG: Using current_image for multimodal content")
|
||||
|
||||
# Convert PIL image to base64
|
||||
import io
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
self.current_image.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr = img_byte_arr.getvalue()
|
||||
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
|
||||
|
||||
# Extract text content from item
|
||||
user_content = dict(item[0][0]).get("content", "")
|
||||
|
||||
# Try to parse if it's JSON
|
||||
if isinstance(user_content, str) and (
|
||||
user_content.startswith("[") or user_content.startswith("{")
|
||||
):
|
||||
try:
|
||||
parsed = json.loads(user_content)
|
||||
text_content = ""
|
||||
for element in parsed:
|
||||
if element.get("type") == "text":
|
||||
text_content = element.get("text", "")
|
||||
|
||||
if not text_content:
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error parsing JSON: {e}")
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
else:
|
||||
text_content = user_content
|
||||
|
||||
# Create messages with the new format
|
||||
print("DEBUG: Creating multimodal message with new format")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": text_content},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
else:
|
||||
print("DEBUG: No image available, using text-only message")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You must submit your answer with \\boxed{answer}",
|
||||
},
|
||||
dict(item[0][0]),
|
||||
]
|
||||
|
||||
print("DEBUG: About to call chat_completion")
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024 * 2,
|
||||
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
|
||||
)
|
||||
print("DEBUG: chat_completion call successful")
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
|
||||
messages = (
|
||||
dict(item[0][0]),
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append((messages, item[1], base64_image))
|
||||
|
||||
print("DEBUG: Finished processing completions")
|
||||
|
||||
print("DEBUG: Returning from collect_trajectories")
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
pass
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the environment, this is called every steps_per_eval steps
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return: None.
|
||||
"""
|
||||
return
|
||||
|
||||
async def setup(self):
|
||||
"""Setup the environment and load the multimodal dataset"""
|
||||
self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")
|
||||
self.train = self.dataset["train"]
|
||||
self.iter = 0
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Get the next items to be rolled out, including the image
|
||||
"""
|
||||
try:
|
||||
print("DEBUG: Starting get_next_item")
|
||||
|
||||
# Get next dataset item
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
||||
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
|
||||
|
||||
# For debugging, we'll use a simple text-only prompt and store the image separately
|
||||
# This is because the collect_trajectories method will handle the multimodal formatting
|
||||
|
||||
# Store image as a class attribute so collect_trajectories can access it
|
||||
self.current_image = next_item["image"]
|
||||
print("DEBUG: Stored image in current_image attribute")
|
||||
|
||||
# Create a simple text prompt - the image will be added in collect_trajectories
|
||||
# This avoids the unhashable type error with lists in frozensets
|
||||
text_prompt = next_item["problem"]
|
||||
|
||||
# Create a simple text-only prompt
|
||||
prompt = tuple(
|
||||
[frozenset({"role": "user", "content": text_prompt}.items())]
|
||||
)
|
||||
answer = next_item["solution"]
|
||||
|
||||
# Convert PIL image to base64
|
||||
import io
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
self.current_image.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr = img_byte_arr.getvalue()
|
||||
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
|
||||
|
||||
print("DEBUG: Created simple text-only prompt for get_next_item")
|
||||
return (prompt, answer, base64_image)
|
||||
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error in get_next_item: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Create a dummy item as fallback
|
||||
prompt = tuple(
|
||||
[
|
||||
frozenset(
|
||||
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
|
||||
)
|
||||
]
|
||||
)
|
||||
answer = "4"
|
||||
return (prompt, answer, "obobob")
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
scores["images"] = list()
|
||||
random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
|
||||
# Extract the answer from the model's response
|
||||
try:
|
||||
model_answer = (
|
||||
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
|
||||
)
|
||||
|
||||
# Handle both numeric and yes/no answers
|
||||
gold_answer = rollout_group_data[0][1]
|
||||
|
||||
# Case-insensitive comparison for yes/no and direct comparison for numbers
|
||||
if gold_answer.lower() in ["yes", "no"]:
|
||||
reward = gold_answer.lower() == model_answer.lower()
|
||||
else:
|
||||
# For numeric answers
|
||||
reward = gold_answer == model_answer
|
||||
|
||||
except IndexError:
|
||||
reward = False
|
||||
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(1.0 if reward else -1.0)
|
||||
|
||||
try:
|
||||
scores["images"].append(item[2])
|
||||
except IndexError:
|
||||
scores["images"].append(None)
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
return scores
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
print("Please set it using: export OPENAI_API_KEY=your_api_key")
|
||||
sys.exit(1)
|
||||
|
||||
print(
|
||||
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
|
||||
)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="clevr_complex",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
print("DEBUG: Creating OpenAI configuration")
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
),
|
||||
]
|
||||
|
||||
return config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MultimodalComplexEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue