diff --git a/atroposlib/cli/view_run_multimodal.py b/atroposlib/cli/view_run_multimodal.py
new file mode 100644
index 00000000..c7ae0592
--- /dev/null
+++ b/atroposlib/cli/view_run_multimodal.py
@@ -0,0 +1,184 @@
+import argparse
+import asyncio
+import base64
+import re
+from io import BytesIO
+
+import aiohttp
+import gradio as gr
+import PIL.Image
+from transformers import AutoTokenizer
+
+
+def find_common_prefix(strings):
+    if not strings:
+        return ""
+
+    prefix = strings[0]
+    for s in strings[1:]:
+        while not s.startswith(prefix):
+            prefix = prefix[:-1]
+            if not prefix:
+                return ""
+    return prefix
+
+
+async def register_to_api(group_size, max_token_len):
+    async with aiohttp.ClientSession() as session:
+        async with session.get("http://localhost:8000/reset_data") as response:
+            print(await response.text())
+        print(group_size)
+        async with session.post(
+            "http://localhost:8000/register",
+            json={
+                "wandb_group": "test",
+                "wandb_project": "test",
+                "batch_size": group_size
+                * 8,  # * 8 in case you want to sample from a larger group
+                "max_token_len": max_token_len,
+                "checkpoint_dir": "checkpoints",
+                "save_checkpoint_interval": 10,
+                "starting_step": 0,
+                "num_steps": 69,
+            },
+        ) as response:
+            print("output of register is")
+            print(await response.text())
+
+
+async def check_for_batch():
+    while True:
+        async with aiohttp.ClientSession() as session:
+            async with session.get("http://localhost:8000/batch") as response:
+                data = await response.json()
+                print(data)
+                if data["batch"] is not None:
+                    return data["batch"]
+        await asyncio.sleep(1)
+
+
+def extract_image_from_chat(chat_text):
+    # Extract the base64 image data from the chat text
+    # Support both jpeg and png formats
+    image_pattern = r'data:image/(jpeg|png);base64,([^"\\]*)'
+    match = re.search(image_pattern, chat_text)
+
+    if match:
+        base64_data = match.group(2)
+        try:
+            image_data = base64.b64decode(base64_data)
+            image = PIL.Image.open(BytesIO(image_data))
+            return image
+        except Exception as e:
+            print(f"Error decoding image: {e}")
+    return None
+
+
+def extract_text_from_chat(chat_text):
+    # Try to extract text from JSON format first
+    # Check if this is JSON multimodal content
+    if '"type": "text"' in chat_text:
+        text_pattern = r'"type": "text", "text": "([^"]*)"'
+        match = re.search(text_pattern, chat_text)
+        if match:
+            return match.group(1)
+
+    # If not in JSON format, look for [Image] prefix
+    if "[Image]" in chat_text:
+        return chat_text.split("[Image]", 1)[1].strip()
+
+    # Return original text if no pattern is found
+    return chat_text
+
+
+async def build_interface(group_size, max_token_len, tokenizer, port):
+    async def grab_batch():
+        tok = AutoTokenizer.from_pretrained(tokenizer)
+        data = await check_for_batch()
+        print(data)
+        chats = [tok.decode(chat) for chat in data[0]["tokens"]]
+
+        # Find common prefix
+        prefix = find_common_prefix(chats)
+
+        # Handle base64 encoded image
+        try:
+            if "images" in data[0] and data[0]["images"] and data[0]["images"][0]:
+                print("Found image data in batch")
+                # Convert base64 string to image
+                base64_image = data[0]["images"][0]
+
+                # If it's already a PIL Image, use it directly
+                if isinstance(base64_image, PIL.Image.Image):
+                    image = base64_image
+                # If it's a base64 string, decode it
+                elif isinstance(base64_image, str):
+                    # Remove data:image prefix if present
+                    if base64_image.startswith("data:image"):
+                        # Extract just the base64 part
+                        image_data = base64_image.split(",", 1)[1]
+                    else:
+                        image_data = base64_image
+
+                    # Decode base64 to bytes and create image
+                    image_bytes = base64.b64decode(image_data)
+                    image = PIL.Image.open(BytesIO(image_bytes))
+                else:
+                    print(f"Image type not recognized: {type(base64_image)}")
+                    image = None
+            else:
+                # Try to extract image from chat text as fallback
+                print("No images field found, trying to extract from chat text")
+                image = extract_image_from_chat(prefix)
+        except Exception as e:
+            print(f"Error processing image: {e}")
+            image = None
+
+        # Extract text prompt from prefix
+        text_prompt = extract_text_from_chat(prefix)
+
+        return (
+            image,  # Image
+            text_prompt,  # Text prompt
+            *[chat.split(prefix)[1] for chat in chats[:group_size]],  # Model outputs
+            *data[0]["scores"][:group_size],  # Scores
+        )
+
+    with gr.Blocks() as demo:
+        image_blk = gr.Image(label="Image", type="pil")
+        prompt_blk = gr.Textbox(label="Text Prompt")
+
+        with gr.Row():
+            score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)]
+
+        with gr.Row():
+            outputs_blks = [
+                gr.Textbox(label=f"Output_{i+1}") for i in range(group_size)
+            ]
+
+        with gr.Row():
+            grab_next = gr.Button(value="Grab Next Batch")
+
+        grab_next.click(
+            fn=grab_batch,
+            outputs=[image_blk, prompt_blk] + outputs_blks + score_blks,
+            api_name="get_batch",
+        )
+    await register_to_api(group_size, max_token_len)
+    demo.launch(server_port=port, share=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=9001)
+    parser.add_argument("--group-size", type=int, default=2)
+    parser.add_argument("--max-token-len", type=int, default=2048)
+    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2-VL-2B-Instruct")
+    args = parser.parse_args()
+    asyncio.run(
+        build_interface(args.group_size, args.max_token_len, args.tokenizer, args.port)
+    )
+
+
+if __name__ == "__main__":
+    main()
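The viewer accepts either raw base64 strings or full `data:image/...;base64,` URIs in the batch's `images` field. A minimal standalone sketch of the round-trip `grab_batch()` performs (not part of the patch; names here are illustrative):

```python
# Encode a PIL image to a base64 data URI, then decode it back the way the
# viewer does. Standalone illustration only.
import base64
from io import BytesIO

import PIL.Image


def to_data_uri(image: PIL.Image.Image) -> str:
    buf = BytesIO()
    image.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")


def from_data_uri(uri: str) -> PIL.Image.Image:
    # Strip the "data:image/...;base64," prefix if present, as the viewer does
    payload = uri.split(",", 1)[1] if uri.startswith("data:image") else uri
    return PIL.Image.open(BytesIO(base64.b64decode(payload)))


if __name__ == "__main__":
    img = PIL.Image.new("RGB", (4, 4), color="red")
    assert from_data_uri(to_data_uri(img)).size == (4, 4)
```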
diff --git a/environments/multimodal_dpo/clevr_cogen_a_train.py b/environments/multimodal_dpo/clevr_cogen_a_train.py
index f9e786e3..a1c60a26 100644
--- a/environments/multimodal_dpo/clevr_cogen_a_train.py
+++ b/environments/multimodal_dpo/clevr_cogen_a_train.py
@@ -1,9 +1,7 @@
 import base64
 import json
-import os
 import random
 import re
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -24,13 +22,11 @@ class MultimodalExampleEnv(BaseEnv):
     async def collect_trajectories(
         self, item: Item
     ) -> Tuple[GameHistory | None, List[Item]]:
-        print("DEBUG: Starting collect_trajectories")
         to_score = list()
         to_backlog = list()
 
         # Get the current image if it was stored
         if hasattr(self, "current_image"):
-            print("DEBUG: Using current_image for multimodal content")
             # Convert PIL image to base64
             import io
 
@@ -56,14 +52,12 @@ class MultimodalExampleEnv(BaseEnv):
                     if not text_content:
                         text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
 
-                except Exception as e:
-                    print(f"DEBUG: Error parsing JSON: {e}")
+                except Exception:
                     text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
             else:
                 text_content = user_content
 
             # Create messages with the new format
-            print("DEBUG: Creating multimodal message with new format")
             messages = [
                 {
                     "role": "system",
@@ -84,7 +78,6 @@ class MultimodalExampleEnv(BaseEnv):
             ]
 
         else:
-            print("DEBUG: No image available, using text-only message")
             messages = [
                 {
                     "role": "system",
@@ -93,26 +86,20 @@ class MultimodalExampleEnv(BaseEnv):
                 dict(item[0][0]),
             ]
 
-        print("DEBUG: About to call chat_completion")
         chat_completions = await self.server.chat_completion(
             messages=messages,
             n=self.config.group_size,
             max_tokens=1024 * 2,
             timeout=60,  # Add timeout to prevent hanging (60 seconds is more reasonable)
         )
-        print("DEBUG: chat_completion call successful")
 
         for i, chat_completion in enumerate(chat_completions.choices):
-            print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
             messages = (
                 dict(item[0][0]),
                 {"role": "assistant", "content": chat_completion.message.content},
            )
             to_score.append((messages, item[1], base64_image))
 
-        print("DEBUG: Finished processing completions")
-
-        print("DEBUG: Returning from collect_trajectories")
         return to_score, to_backlog
 
     async def postprocess_histories(
@@ -141,20 +128,12 @@ class MultimodalExampleEnv(BaseEnv):
         Get the next items to be rolled out, including the image
         """
         try:
-            print("DEBUG: Starting get_next_item")
-
             # Get next dataset item
             next_item = self.train[self.iter % len(self.train)]
             self.iter += 1
 
-            print(f"DEBUG: Retrieved dataset item {self.iter-1}")
-
-            # For debugging, we'll use a simple text-only prompt and store the image separately
-            # This is because the collect_trajectories method will handle the multimodal formatting
-
             # Store image as a class attribute so collect_trajectories can access it
             self.current_image = next_item["image"]
-            print("DEBUG: Stored image in current_image attribute")
 
             # Create a simple text prompt - the image will be added in collect_trajectories
             # This avoids the unhashable type error with lists in frozensets
@@ -177,11 +156,9 @@ class MultimodalExampleEnv(BaseEnv):
             img_byte_arr = img_byte_arr.getvalue()
             base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
 
-            print("DEBUG: Created simple text-only prompt for get_next_item")
             return (prompt, answer, base64_image)
 
-        except Exception as e:
-            print(f"DEBUG: Error in get_next_item: {str(e)}")
+        except Exception:
             traceback.print_exc()
 
             # Create a dummy item as fallback
@@ -212,9 +189,6 @@ class MultimodalExampleEnv(BaseEnv):
                 model_answer = (
                     item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                 )
-                print(
-                    f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
-                )
 
                 pattern = r"\s*(\d{1,2})\s*"
                 string = rollout_group_data[0][1]
@@ -243,35 +217,26 @@ class MultimodalExampleEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            print("Please set it using: export OPENAI_API_KEY=your_api_key")
-            sys.exit(1)
-
-        print(
-            f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
-        )
 
         config = BaseEnvConfig(
-            wandb_name="clevr_cogen",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            wandb_name="clevr_cogen_a_train",
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
         )
-        print("DEBUG: Creating OpenAI configuration")
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",  # Using GPT-4o which has multimodal capabilities
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
             ),
         ]
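The scoring path in these CLEVR environments extracts the model's `\boxed{...}` answer by string splitting and compares it against a 1-2 digit number pulled from the gold answer with the `\s*(\d{1,2})\s*` pattern. A standalone sketch of just that extraction logic (the surrounding `ScoredDataGroup` bookkeeping omitted; `score_boxed_answer` is an illustrative name, not a repo function):

```python
# Sketch of the \boxed{...} answer check: split the boxed span out of the
# completion, pull a 1-2 digit number from the gold answer, score on a match.
import re


def score_boxed_answer(completion: str, gold: str) -> float:
    model_answer = completion.split("\\boxed{")[-1].split("}")[0]
    match = re.search(r"\s*(\d{1,2})\s*", gold)
    gold_answer = match.group(1) if match else gold.strip()
    return 1.0 if model_answer.strip() == gold_answer else 0.0


print(score_boxed_answer("The count is \\boxed{7}", "7"))  # 1.0
print(score_boxed_answer("I think \\boxed{12}", "3"))  # 0.0
```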
diff --git a/environments/multimodal_dpo/clevr_complex.py b/environments/multimodal_dpo/clevr_complex.py
index ca99b83c..4f68080f 100644
--- a/environments/multimodal_dpo/clevr_complex.py
+++ b/environments/multimodal_dpo/clevr_complex.py
@@ -1,8 +1,6 @@
 import base64
 import json
-import os
 import random
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -10,6 +8,7 @@ from datasets import load_dataset
 
 from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
 from atroposlib.type_definitions import GameHistory, Item
+
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 
@@ -23,13 +22,11 @@ class MultimodalComplexEnv(BaseEnv):
     async def collect_trajectories(
         self, item: Item
     ) -> Tuple[GameHistory | None, List[Item]]:
-        print("DEBUG: Starting collect_trajectories")
         to_score = list()
         to_backlog = list()
 
         # Get the current image if it was stored
         if hasattr(self, "current_image"):
-            print("DEBUG: Using current_image for multimodal content")
             # Convert PIL image to base64
             import io
 
@@ -55,14 +52,12 @@ class MultimodalComplexEnv(BaseEnv):
                     if not text_content:
                         text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
 
-                except Exception as e:
-                    print(f"DEBUG: Error parsing JSON: {e}")
+                except Exception:
                     text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
             else:
                 text_content = user_content
 
             # Create messages with the new format
-            print("DEBUG: Creating multimodal message with new format")
             messages = [
                 {
                     "role": "system",
@@ -83,7 +78,6 @@ class MultimodalComplexEnv(BaseEnv):
             ]
 
         else:
-            print("DEBUG: No image available, using text-only message")
             messages = [
                 {
                     "role": "system",
@@ -92,32 +86,23 @@ class MultimodalComplexEnv(BaseEnv):
                 dict(item[0][0]),
             ]
 
-        print("DEBUG: About to call chat_completion")
         chat_completions = await self.server.chat_completion(
             messages=messages,
             n=self.config.group_size,
             max_tokens=1024 * 2,
             timeout=60,  # Add timeout to prevent hanging (60 seconds is more reasonable)
         )
-        print("DEBUG: chat_completion call successful")
 
         for i, chat_completion in enumerate(chat_completions.choices):
-            print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
             messages = (
                 dict(item[0][0]),
                 {"role": "assistant", "content": chat_completion.message.content},
             )
             to_score.append((messages, item[1], base64_image))
 
-        print("DEBUG: Finished processing completions")
+        to_postprocess = await self.score(to_score)
 
-        print("DEBUG: Returning from collect_trajectories")
-        return to_score, to_backlog
-
-    async def postprocess_histories(
-        self, trajectories: List[GameHistory]
-    ) -> ScoredDataGroup:
-        pass
+        return to_postprocess, to_backlog
 
     async def evaluate(self, *args, **kwargs):
         """
@@ -140,20 +125,13 @@ class MultimodalComplexEnv(BaseEnv):
         Get the next items to be rolled out, including the image
         """
         try:
-            print("DEBUG: Starting get_next_item")
             # Get next dataset item
             next_item = self.train[self.iter % len(self.train)]
             self.iter += 1
 
-            print(f"DEBUG: Retrieved dataset item {self.iter-1}")
-
-            # For debugging, we'll use a simple text-only prompt and store the image separately
-            # This is because the collect_trajectories method will handle the multimodal formatting
-
             # Store image as a class attribute so collect_trajectories can access it
             self.current_image = next_item["image"]
-            print("DEBUG: Stored image in current_image attribute")
 
             # Create a simple text prompt - the image will be added in collect_trajectories
             # This avoids the unhashable type error with lists in frozensets
@@ -173,11 +151,9 @@ class MultimodalComplexEnv(BaseEnv):
             img_byte_arr = img_byte_arr.getvalue()
             base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
 
-            print("DEBUG: Created simple text-only prompt for get_next_item")
             return (prompt, answer, base64_image)
 
-        except Exception as e:
-            print(f"DEBUG: Error in get_next_item: {str(e)}")
+        except Exception:
             traceback.print_exc()
 
             # Create a dummy item as fallback
@@ -208,9 +184,6 @@ class MultimodalComplexEnv(BaseEnv):
                 model_answer = (
                     item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                 )
-                print(
-                    f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
-                )
 
                 # Handle both numeric and yes/no answers
                 gold_answer = rollout_group_data[0][1]
@@ -244,35 +217,26 @@ class MultimodalComplexEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            print("Please set it using: export OPENAI_API_KEY=your_api_key")
-            sys.exit(1)
-
-        print(
-            f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
-        )
 
         config = BaseEnvConfig(
             wandb_name="clevr_complex",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
         )
-        print("DEBUG: Creating OpenAI configuration")
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",  # Using GPT-4o which has multimodal capabilities
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
            ),
         ]
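The structural change repeated across the remaining environments appears here first: `collect_trajectories` now scores its rollouts inline via `await self.score(...)` and returns the scored group, and the `postprocess_histories` stub is dropped. A toy sketch of the new control flow; `ToyEnv` and its fake scoring are hypothetical stand-ins, not atroposlib APIs:

```python
# Toy illustration of the refactor: collect_trajectories awaits self.score()
# and returns the scored group directly, instead of returning raw rollouts
# and leaving scoring to a postprocess_histories stub.
import asyncio


class ToyEnv:
    async def score(self, rollouts):
        # Pretend-score mirroring the shape (not the logic) of the real envs
        return [(r, float(len(r))) for r in rollouts]

    async def collect_trajectories(self, item):
        to_backlog = []
        to_score = [f"answer-{item}-{i}" for i in range(2)]  # fake rollouts
        to_postprocess = await self.score(to_score)  # new inline scoring step
        return to_postprocess, to_backlog


print(asyncio.run(ToyEnv().collect_trajectories("q1")))
```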
diff --git a/environments/multimodal_dpo/ocr_vqa.py b/environments/multimodal_dpo/ocr_vqa.py
index aa6978ef..1ac43556 100644
--- a/environments/multimodal_dpo/ocr_vqa.py
+++ b/environments/multimodal_dpo/ocr_vqa.py
@@ -1,9 +1,7 @@
 import base64
 import io
-import os
 import random
 import re
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -68,13 +66,8 @@ class OcrVqaEnv(BaseEnv):
             history: GameHistory = (user_hist, assistant_hist)
             to_score.append((history, gold, base64_image))
 
-        return to_score, to_backlog
-
-    async def postprocess_histories(
-        self, trajectories: List[GameHistory]
-    ) -> ScoredDataGroup:
-        # No additional post-processing needed
-        pass
+        to_postprocess = await self.score(to_score)
+        return to_postprocess, to_backlog
 
     async def evaluate(self, *args, **kwargs):
         # No custom evaluation
@@ -167,29 +160,26 @@ class OcrVqaEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            sys.exit(1)
 
         config = BaseEnvConfig(
             wandb_name="ocr_vqa",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
         )
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
             ),
         ]
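All of these environments send the image as a base64 data URI inside an OpenAI-style multimodal user message, which is also the shape the viewer's regexes expect. A minimal sketch of that message, assuming the standard chat-completions content-parts format:

```python
# Build a multimodal message: a text part plus an image_url part carrying a
# base64 data URI, in the OpenAI-style content-parts convention.
import base64
import io

import PIL.Image

img = PIL.Image.new("RGB", (8, 8))
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

messages = [
    {"role": "system", "content": "Answer with \\boxed{answer}."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "How many objects are in the image?"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ],
    },
]
```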
diff --git a/environments/multimodal_dpo/pixmo_clocks.py b/environments/multimodal_dpo/pixmo_clocks.py
index bf4059e8..d929e9b4 100644
--- a/environments/multimodal_dpo/pixmo_clocks.py
+++ b/environments/multimodal_dpo/pixmo_clocks.py
@@ -1,9 +1,7 @@
 import base64
 import io
-import os
 import random
 import re
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -68,13 +66,9 @@ class ClockDatasetEnv(BaseEnv):
             history: GameHistory = (user_msg, assistant_msg)
             to_score.append((history, item[1], base64_image))
 
-        return to_score, to_backlog
+        to_postprocess = await self.score(to_score)
 
-    async def postprocess_histories(
-        self, trajectories: List[GameHistory]
-    ) -> ScoredDataGroup:
-        # No custom post-processing
-        pass
+        return to_postprocess, to_backlog
 
     async def evaluate(self, *args, **kwargs):
         # No custom evaluation
@@ -87,6 +81,7 @@ class ClockDatasetEnv(BaseEnv):
         self.iter = 0
 
     async def get_next_item(self) -> Item:
+
         try:
             entry = self.train[self.iter % len(self.train)]
             self.iter += 1
@@ -128,6 +123,7 @@ class ClockDatasetEnv(BaseEnv):
         scores["scores"] = []
         scores["images"] = []
         random.shuffle(rollout_group_data)
+
         for item in rollout_group_data:
             out = tokenize_for_trainer(self.tokenizer, item[0])
             tokens = out["tokens"]
@@ -169,29 +165,26 @@ class ClockDatasetEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            sys.exit(1)
 
         config = BaseEnvConfig(
-            wandb_name="clocks",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            wandb_name="pixmo_clocks",
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
         )
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
             ),
         ]
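The `score()` hunk above shows the scored group now carrying an `images` list alongside the per-rollout scores, with tokenization delegated to `tokenize_for_trainer`. A rough sketch of that assembly; `fake_tokenize` is a stand-in for the real helper, and the `tokens`/`masks` keys beyond what the hunk shows are assumptions:

```python
# Sketch of the scored-group assembly: each rollout contributes tokens/masks
# from the tokenizer helper plus a score and the base64 image.
import random


def fake_tokenize(history):
    # Stand-in for tokenize_for_trainer(self.tokenizer, history)
    tokens = list(range(len(str(history))))
    return {"tokens": tokens, "masks": [1] * len(tokens)}


rollout_group_data = [(("user msg", "assistant msg"), 1.0, "<base64 png>")]
scores = {"tokens": [], "masks": [], "scores": [], "images": []}
random.shuffle(rollout_group_data)

for item in rollout_group_data:
    out = fake_tokenize(item[0])
    scores["tokens"].append(out["tokens"])
    scores["masks"].append(out["masks"])
    scores["scores"].append(item[1])  # stand-in score value
    scores["images"].append(item[2])
```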
diff --git a/environments/multimodal_dpo/pixmo_count.py b/environments/multimodal_dpo/pixmo_count.py
index ab158da6..6a23990d 100644
--- a/environments/multimodal_dpo/pixmo_count.py
+++ b/environments/multimodal_dpo/pixmo_count.py
@@ -1,9 +1,7 @@
 import base64
 import io
-import os
 import random
 import re
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -103,13 +101,9 @@ class PixmoCountEnv(BaseEnv):
             history: GameHistory = (user_hist, assistant_hist)
             to_score.append((history, gold, base64_image))
 
-        return to_score, to_backlog
+        to_postprocess = await self.score(to_score)
 
-    async def postprocess_histories(
-        self, trajectories: List[GameHistory]
-    ) -> ScoredDataGroup:
-        # No custom post-processing
-        pass
+        return to_postprocess, to_backlog
 
     async def evaluate(self, *args, **kwargs):
         # No custom evaluation
@@ -158,29 +152,26 @@ class PixmoCountEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            sys.exit(1)
 
         config = BaseEnvConfig(
             wandb_name="pixmo_count",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
         )
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
             ),
         ]
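Every config now targets a local OpenAI-compatible endpoint at `http://localhost:9001/v1` serving `Qwen/Qwen2-VL-2B-Instruct`, with the placeholder API key `"x"`. A quick sanity check against such an endpoint with the `openai` client, assuming the server (e.g. vLLM) is already running:

```python
# Verify the local OpenAI-compatible server the configs point at is up.
# "x" is the same placeholder key used in the configs.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9001/v1", api_key="x")
resp = client.chat.completions.create(
    model="Qwen/Qwen2-VL-2B-Instruct",
    messages=[{"role": "user", "content": "Reply with \\boxed{ok}."}],
    max_tokens=32,
)
print(resp.choices[0].message.content)
```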
diff --git a/environments/multimodal_dpo/pixmo_point_explanations.py b/environments/multimodal_dpo/pixmo_point_explanations.py
index 5932a867..1eef6a4a 100644
--- a/environments/multimodal_dpo/pixmo_point_explanations.py
+++ b/environments/multimodal_dpo/pixmo_point_explanations.py
@@ -1,9 +1,7 @@
 import base64
 import io
-import os
 import random
 import re
-import sys
 import traceback
 from typing import List, Optional, Tuple
 
@@ -50,8 +48,7 @@ class PixmoPointExplanationsEnv(BaseEnv):
             img.save(buf, format="PNG")
             img_bytes = buf.getvalue()
             base64_image = base64.b64encode(img_bytes).decode("utf-8")
-        except Exception as e:
-            print(f"Error loading image from URL: {e}")
+        except Exception:
             base64_image = None
 
         return (prompt, gold_answer, base64_image)
@@ -113,13 +110,9 @@ class PixmoPointExplanationsEnv(BaseEnv):
             history: GameHistory = (user_hist, assistant_hist)
             to_score.append((history, gold, base64_image))
 
-        return to_score, to_backlog
+        to_postprocess = await self.score(to_score)
 
-    async def postprocess_histories(
-        self, trajectories: List[GameHistory]
-    ) -> ScoredDataGroup:
-        # No custom post-processing needed
-        pass
+        return to_postprocess, to_backlog
 
     async def evaluate(self, *args, **kwargs):
         # No custom evaluation
@@ -169,29 +162,26 @@ class PixmoPointExplanationsEnv(BaseEnv):
 
     @classmethod
     def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
-        if not os.environ.get("OPENAI_API_KEY"):
-            print("ERROR: OPENAI_API_KEY environment variable is not set!")
-            sys.exit(1)
 
         config = BaseEnvConfig(
             wandb_name="pixmo_point_explanations",
-            tokenizer_name="gpt2",
-            group_size=2,
-            use_wandb=False,
+            tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
+            group_size=8,
+            use_wandb=True,
             max_num_workers=2,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
-            batch_size=1,
-            steps_per_eval=10,
-            ensure_scores_are_not_same=False,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
        )
 
         server_configs = [
             OpenaiConfig(
-                model_name="gpt-4o",
-                base_url=None,
-                api_key=os.environ.get("OPENAI_API_KEY"),
-                num_requests_for_eval=1,
+                model_name="Qwen/Qwen2-VL-2B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
             ),
         ]
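This environment loads its image from a URL and, after the change above, falls back to `base64_image = None` on any failure rather than printing the error. A standalone sketch of that fetch-encode-or-None pattern, using `requests` for illustration (the env's own loader is not shown in this hunk and may differ):

```python
# Download an image, re-encode it as base64 PNG, and fall back to None on
# any failure, mirroring the pattern in get_next_item.
import base64
import io

import PIL.Image
import requests


def load_image_as_base64(url: str):
    try:
        img = PIL.Image.open(io.BytesIO(requests.get(url, timeout=10).content))
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception:
        return None
```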
diff --git a/pyproject.toml b/pyproject.toml
index 50ad2c59..36b6f2bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
 run-api = "atroposlib.cli.run_api:main"
 inference-node-wandb-watcher = "atroposlib.cli.inference_node_wandb_watcher:main"
 view-run = "atroposlib.cli.view_run:main"
+view-run-multimodal = "atroposlib.cli.view_run_multimodal:main"
 atropos-sft-gen = "atroposlib.cli.sft:main"
 atropos-dpo-gen = "atroposlib.cli.dpo:main"
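With this entry-point registered, reinstalling the package (e.g. `pip install -e .`) exposes the viewer as a `view-run-multimodal` console command alongside the existing `view-run`. It registers against the rollout API on `localhost:8000` and serves the Gradio UI on `--port` (default 9001), matching the `base_url` the environment configs point at.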