Fix multimodal envs. Add view_run_multimodal.

This commit is contained in:
Artem Yatsenko 2025-05-07 21:53:01 +00:00
parent a282604baa
commit 0f15be68a2
8 changed files with 265 additions and 187 deletions

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -68,13 +66,9 @@ class ClockDatasetEnv(BaseEnv):
history: GameHistory = (user_msg, assistant_msg)
to_score.append((history, item[1], base64_image))
return to_score, to_backlog
to_postprocess = await self.score(to_score)
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing
pass
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -87,6 +81,7 @@ class ClockDatasetEnv(BaseEnv):
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
@ -128,6 +123,7 @@ class ClockDatasetEnv(BaseEnv):
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
@ -169,29 +165,26 @@ class ClockDatasetEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="clocks",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
wandb_name="pixmo_clocks",
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]