fix multimodal envs. add view_run_multimodal

This commit is contained in:
Artem Yatsenko 2025-05-07 21:53:01 +00:00
parent a282604baa
commit 0f15be68a2
8 changed files with 265 additions and 187 deletions

View file

@ -1,9 +1,7 @@
import base64
import json
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -24,13 +22,11 @@ class MultimodalExampleEnv(BaseEnv):
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
@ -56,14 +52,12 @@ class MultimodalExampleEnv(BaseEnv):
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
except Exception:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
@ -84,7 +78,6 @@ class MultimodalExampleEnv(BaseEnv):
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
@ -93,26 +86,20 @@ class MultimodalExampleEnv(BaseEnv):
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
@ -141,20 +128,12 @@ class MultimodalExampleEnv(BaseEnv):
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
@ -177,11 +156,9 @@ class MultimodalExampleEnv(BaseEnv):
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
except Exception:
traceback.print_exc()
# Create a dummy item as fallback
@ -212,9 +189,6 @@ class MultimodalExampleEnv(BaseEnv):
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
pattern = r"<answer>\s*(\d{1,2})\s*</answer>"
string = rollout_group_data[0][1]
@ -243,35 +217,26 @@ class MultimodalExampleEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_cogen",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
wandb_name="clevr_cogen_a_train",
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]

View file

@ -1,8 +1,6 @@
import base64
import json
import os
import random
import sys
import traceback
from typing import List, Optional, Tuple
@ -10,6 +8,7 @@ from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
@ -23,13 +22,11 @@ class MultimodalComplexEnv(BaseEnv):
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
@ -55,14 +52,12 @@ class MultimodalComplexEnv(BaseEnv):
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
except Exception:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
@ -83,7 +78,6 @@ class MultimodalComplexEnv(BaseEnv):
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
@ -92,32 +86,23 @@ class MultimodalComplexEnv(BaseEnv):
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
to_postprocess = await self.score(to_score)
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
"""
@ -140,20 +125,13 @@ class MultimodalComplexEnv(BaseEnv):
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
@ -173,11 +151,9 @@ class MultimodalComplexEnv(BaseEnv):
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
except Exception:
traceback.print_exc()
# Create a dummy item as fallback
@ -208,9 +184,6 @@ class MultimodalComplexEnv(BaseEnv):
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
# Handle both numeric and yes/no answers
gold_answer = rollout_group_data[0][1]
@ -244,35 +217,26 @@ class MultimodalComplexEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_complex",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -68,13 +66,8 @@ class OcrVqaEnv(BaseEnv):
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No additional post-processing needed
pass
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -167,29 +160,26 @@ class OcrVqaEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="ocr_vqa",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -68,13 +66,9 @@ class ClockDatasetEnv(BaseEnv):
history: GameHistory = (user_msg, assistant_msg)
to_score.append((history, item[1], base64_image))
return to_score, to_backlog
to_postprocess = await self.score(to_score)
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -87,6 +81,7 @@ class ClockDatasetEnv(BaseEnv):
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
@ -128,6 +123,7 @@ class ClockDatasetEnv(BaseEnv):
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
@ -169,29 +165,26 @@ class ClockDatasetEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="clocks",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
wandb_name="pixmo_clocks",
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -103,13 +101,9 @@ class PixmoCountEnv(BaseEnv):
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
to_postprocess = await self.score(to_score)
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -158,29 +152,26 @@ class PixmoCountEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_count",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -50,8 +48,7 @@ class PixmoPointExplanationsEnv(BaseEnv):
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
print(f"Error loading image from URL: {e}")
except Exception:
base64_image = None
return (prompt, gold_answer, base64_image)
@ -113,13 +110,9 @@ class PixmoPointExplanationsEnv(BaseEnv):
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
to_postprocess = await self.score(to_score)
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing needed
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -169,29 +162,26 @@ class PixmoPointExplanationsEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_point_explanations",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]