Fix multimodal envs; add view_run_multimodal

This commit is contained in:
Artem Yatsenko 2025-05-07 21:53:01 +00:00
parent a282604baa
commit 0f15be68a2
8 changed files with 265 additions and 187 deletions

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -50,8 +48,7 @@ class PixmoPointExplanationsEnv(BaseEnv):
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
print(f"Error loading image from URL: {e}")
except Exception:
base64_image = None
return (prompt, gold_answer, base64_image)
@ -113,13 +110,9 @@ class PixmoPointExplanationsEnv(BaseEnv):
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
to_postprocess = await self.score(to_score)
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing needed
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -169,29 +162,26 @@ class PixmoPointExplanationsEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_point_explanations",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]