mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix multimodal envs. add view_run_multimodal
This commit is contained in:
parent
a282604baa
commit
0f15be68a2
8 changed files with 265 additions and 187 deletions
184
atroposlib/cli/view_run_multimodal.py
Normal file
184
atroposlib/cli/view_run_multimodal.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import gradio as gr
|
||||
import PIL.Image
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def find_common_prefix(strings):
    """Return the longest common prefix shared by every string in *strings*.

    Returns "" for an empty input list or when the strings share nothing.
    """
    if not strings:
        return ""

    candidate = strings[0]
    for other in strings[1:]:
        # Shrink the candidate from the right until it prefixes this string.
        while candidate and not other.startswith(candidate):
            candidate = candidate[:-1]
        if not candidate:
            return ""
    return candidate
|
||||
|
||||
|
||||
async def register_to_api(group_size, max_token_len):
    """Reset the rollout API's stored data and register this viewer as a trainer.

    The registered batch size is inflated (group_size * 8) so a fetched batch
    is likely to contain a full group to display.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get("http://localhost:8000/reset_data") as response:
            print(await response.text())
        print(group_size)
        payload = {
            "wandb_group": "test",
            "wandb_project": "test",
            # * 8 just in case you want to just sample from a large group
            "batch_size": group_size * 8,
            "max_token_len": max_token_len,
            "checkpoint_dir": "checkpoints",
            "save_checkpoint_interval": 10,
            "starting_step": 0,
            "num_steps": 69,
        }
        async with session.post(
            "http://localhost:8000/register", json=payload
        ) as response:
            print("output of register is")
            print(await response.text())
|
||||
|
||||
|
||||
async def check_for_batch():
    """Poll the rollout API once per second until a batch becomes available."""
    while True:
        async with aiohttp.ClientSession() as session:
            async with session.get("http://localhost:8000/batch") as response:
                payload = await response.json()
                print(payload)
                batch = payload["batch"]
                if batch is not None:
                    return batch
        await asyncio.sleep(1)
|
||||
|
||||
|
||||
def extract_image_from_chat(chat_text):
    """Find a base64 data-URI image (jpeg or png) embedded in *chat_text*.

    Returns a PIL image on success, or None when no image is present or the
    payload cannot be decoded.
    """
    # Base64 payload runs until the closing quote/backslash of the data URI.
    match = re.search(r'data:image/(jpeg|png);base64,([^"\\]*)', chat_text)
    if match is None:
        return None

    try:
        raw = base64.b64decode(match.group(2))
        return PIL.Image.open(BytesIO(raw))
    except Exception as e:
        print(f"Error decoding image: {e}")
        return None
|
||||
|
||||
|
||||
def extract_text_from_chat(chat_text):
    """Pull the user-visible text prompt out of a rendered chat string.

    Handles three shapes, in order: JSON multimodal content containing a
    {"type": "text", "text": ...} part, an "[Image]" marker followed by the
    prompt, or plain text (returned unchanged).
    """
    # JSON multimodal content: grab the "text" field of the text part.
    if '"type": "text"' in chat_text:
        hit = re.search(r'"type": "text", "text": "([^"]*)"', chat_text)
        if hit is not None:
            return hit.group(1)

    # "[Image]" marker: everything after the first marker is the prompt.
    if "[Image]" in chat_text:
        return chat_text.split("[Image]", 1)[1].strip()

    # No recognized structure: hand back the original text.
    return chat_text
|
||||
|
||||
|
||||
async def build_interface(group_size, max_token_len, tokenizer, port):
    """Launch a Gradio viewer for multimodal rollout batches.

    Args:
        group_size: number of model outputs / scores shown per batch.
        max_token_len: max token length reported when registering as a trainer.
        tokenizer: HF tokenizer name used to decode the batch's token ids.
        port: local port for the Gradio server.
    """
    # Load the tokenizer once up front; the original reloaded it on every
    # button click, which is needlessly slow.
    tok = AutoTokenizer.from_pretrained(tokenizer)

    async def grab_batch():
        data = await check_for_batch()
        print(data)
        chats = [tok.decode(chat) for chat in data[0]["tokens"]]

        # Common prefix == the shared prompt portion of every chat.
        prefix = find_common_prefix(chats)

        # Handle base64 encoded image
        try:
            if "images" in data[0] and data[0]["images"] and data[0]["images"][0]:
                print("Found image data in batch")
                base64_image = data[0]["images"][0]

                if isinstance(base64_image, PIL.Image.Image):
                    # Already a decoded PIL image, use it directly.
                    image = base64_image
                elif isinstance(base64_image, str):
                    # Strip a data:image URI prefix if present, then decode.
                    if base64_image.startswith("data:image"):
                        image_data = base64_image.split(",", 1)[1]
                    else:
                        image_data = base64_image
                    image_bytes = base64.b64decode(image_data)
                    image = PIL.Image.open(BytesIO(image_bytes))
                else:
                    print(f"Image type not recognized: {type(base64_image)}")
                    image = None
            else:
                # Try to extract image from chat text as fallback
                print("No images field found, trying to extract from chat text")
                image = extract_image_from_chat(prefix)
        except Exception as e:
            print(f"Error processing image: {e}")
            image = None

        # Extract text prompt from prefix
        text_prompt = extract_text_from_chat(prefix)

        # Slice rather than chat.split(prefix)[1]: str.split("") raises
        # ValueError when the prefix is empty, and [1] raises IndexError
        # when a chat equals the prefix exactly.
        return (
            image,  # Image
            text_prompt,  # Text prompt
            *[chat[len(prefix):] for chat in chats[:group_size]],  # Model outputs
            *data[0]["scores"][:group_size],  # Scores
        )

    with gr.Blocks() as demo:
        image_blk = gr.Image(label="Image", type="pil")
        prompt_blk = gr.Textbox(label="Text Prompt")

        with gr.Row():
            score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)]

        with gr.Row():
            outputs_blks = [
                gr.Textbox(label=f"Output_{i+1}") for i in range(group_size)
            ]

        with gr.Row():
            grab_next = gr.Button(value="Grab Next Batch")

        grab_next.click(
            fn=grab_batch,
            outputs=[image_blk, prompt_blk] + outputs_blks + score_blks,
            api_name="get_batch",
        )
    await register_to_api(group_size, max_token_len)
    demo.launch(server_port=port, share=True)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse flags and launch the multimodal batch viewer."""
    parser = argparse.ArgumentParser()
    for flag, kind, default in (
        ("--port", int, 9001),
        ("--group-size", int, 2),
        ("--max-token-len", int, 2048),
        ("--tokenizer", str, "Qwen/Qwen2-VL-2B-Instruct"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    opts = parser.parse_args()

    viewer = build_interface(
        opts.group_size, opts.max_token_len, opts.tokenizer, opts.port
    )
    asyncio.run(viewer)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -24,13 +22,11 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[GameHistory | None, List[Item]]:
|
||||
print("DEBUG: Starting collect_trajectories")
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
# Get the current image if it was stored
|
||||
if hasattr(self, "current_image"):
|
||||
print("DEBUG: Using current_image for multimodal content")
|
||||
|
||||
# Convert PIL image to base64
|
||||
import io
|
||||
|
|
@ -56,14 +52,12 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
|
||||
if not text_content:
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error parsing JSON: {e}")
|
||||
except Exception:
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
else:
|
||||
text_content = user_content
|
||||
|
||||
# Create messages with the new format
|
||||
print("DEBUG: Creating multimodal message with new format")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -84,7 +78,6 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
]
|
||||
|
||||
else:
|
||||
print("DEBUG: No image available, using text-only message")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -93,26 +86,20 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
dict(item[0][0]),
|
||||
]
|
||||
|
||||
print("DEBUG: About to call chat_completion")
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024 * 2,
|
||||
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
|
||||
)
|
||||
print("DEBUG: chat_completion call successful")
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
|
||||
messages = (
|
||||
dict(item[0][0]),
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append((messages, item[1], base64_image))
|
||||
|
||||
print("DEBUG: Finished processing completions")
|
||||
|
||||
print("DEBUG: Returning from collect_trajectories")
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
|
|
@ -141,20 +128,12 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
Get the next items to be rolled out, including the image
|
||||
"""
|
||||
try:
|
||||
print("DEBUG: Starting get_next_item")
|
||||
|
||||
# Get next dataset item
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
||||
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
|
||||
|
||||
# For debugging, we'll use a simple text-only prompt and store the image separately
|
||||
# This is because the collect_trajectories method will handle the multimodal formatting
|
||||
|
||||
# Store image as a class attribute so collect_trajectories can access it
|
||||
self.current_image = next_item["image"]
|
||||
print("DEBUG: Stored image in current_image attribute")
|
||||
|
||||
# Create a simple text prompt - the image will be added in collect_trajectories
|
||||
# This avoids the unhashable type error with lists in frozensets
|
||||
|
|
@ -177,11 +156,9 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
img_byte_arr = img_byte_arr.getvalue()
|
||||
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
|
||||
|
||||
print("DEBUG: Created simple text-only prompt for get_next_item")
|
||||
return (prompt, answer, base64_image)
|
||||
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error in get_next_item: {str(e)}")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
# Create a dummy item as fallback
|
||||
|
|
@ -212,9 +189,6 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
model_answer = (
|
||||
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
|
||||
)
|
||||
|
||||
pattern = r"<answer>\s*(\d{1,2})\s*</answer>"
|
||||
string = rollout_group_data[0][1]
|
||||
|
|
@ -243,35 +217,26 @@ class MultimodalExampleEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
print("Please set it using: export OPENAI_API_KEY=your_api_key")
|
||||
sys.exit(1)
|
||||
|
||||
print(
|
||||
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
|
||||
)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="clevr_cogen",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
wandb_name="clevr_cogen_a_train",
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
print("DEBUG: Creating OpenAI configuration")
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -10,6 +8,7 @@ from datasets import load_dataset
|
|||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
|
|
@ -23,13 +22,11 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[GameHistory | None, List[Item]]:
|
||||
print("DEBUG: Starting collect_trajectories")
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
# Get the current image if it was stored
|
||||
if hasattr(self, "current_image"):
|
||||
print("DEBUG: Using current_image for multimodal content")
|
||||
|
||||
# Convert PIL image to base64
|
||||
import io
|
||||
|
|
@ -55,14 +52,12 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
|
||||
if not text_content:
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error parsing JSON: {e}")
|
||||
except Exception:
|
||||
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
|
||||
else:
|
||||
text_content = user_content
|
||||
|
||||
# Create messages with the new format
|
||||
print("DEBUG: Creating multimodal message with new format")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -83,7 +78,6 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
]
|
||||
|
||||
else:
|
||||
print("DEBUG: No image available, using text-only message")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -92,32 +86,23 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
dict(item[0][0]),
|
||||
]
|
||||
|
||||
print("DEBUG: About to call chat_completion")
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024 * 2,
|
||||
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
|
||||
)
|
||||
print("DEBUG: chat_completion call successful")
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
|
||||
messages = (
|
||||
dict(item[0][0]),
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append((messages, item[1], base64_image))
|
||||
|
||||
print("DEBUG: Finished processing completions")
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
||||
print("DEBUG: Returning from collect_trajectories")
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
pass
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
|
|
@ -140,20 +125,13 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
Get the next items to be rolled out, including the image
|
||||
"""
|
||||
try:
|
||||
print("DEBUG: Starting get_next_item")
|
||||
|
||||
# Get next dataset item
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
||||
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
|
||||
|
||||
# For debugging, we'll use a simple text-only prompt and store the image separately
|
||||
# This is because the collect_trajectories method will handle the multimodal formatting
|
||||
|
||||
# Store image as a class attribute so collect_trajectories can access it
|
||||
self.current_image = next_item["image"]
|
||||
print("DEBUG: Stored image in current_image attribute")
|
||||
|
||||
# Create a simple text prompt - the image will be added in collect_trajectories
|
||||
# This avoids the unhashable type error with lists in frozensets
|
||||
|
|
@ -173,11 +151,9 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
img_byte_arr = img_byte_arr.getvalue()
|
||||
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
|
||||
|
||||
print("DEBUG: Created simple text-only prompt for get_next_item")
|
||||
return (prompt, answer, base64_image)
|
||||
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error in get_next_item: {str(e)}")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
# Create a dummy item as fallback
|
||||
|
|
@ -208,9 +184,6 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
model_answer = (
|
||||
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
|
||||
)
|
||||
|
||||
# Handle both numeric and yes/no answers
|
||||
gold_answer = rollout_group_data[0][1]
|
||||
|
|
@ -244,35 +217,26 @@ class MultimodalComplexEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
print("Please set it using: export OPENAI_API_KEY=your_api_key")
|
||||
sys.exit(1)
|
||||
|
||||
print(
|
||||
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
|
||||
)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="clevr_complex",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
print("DEBUG: Creating OpenAI configuration")
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -68,13 +66,8 @@ class OcrVqaEnv(BaseEnv):
|
|||
history: GameHistory = (user_hist, assistant_hist)
|
||||
to_score.append((history, gold, base64_image))
|
||||
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
# No additional post-processing needed
|
||||
pass
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# No custom evaluation
|
||||
|
|
@ -167,29 +160,26 @@ class OcrVqaEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
sys.exit(1)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="ocr_vqa",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -68,13 +66,9 @@ class ClockDatasetEnv(BaseEnv):
|
|||
history: GameHistory = (user_msg, assistant_msg)
|
||||
to_score.append((history, item[1], base64_image))
|
||||
|
||||
return to_score, to_backlog
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
# No custom post-processing
|
||||
pass
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# No custom evaluation
|
||||
|
|
@ -87,6 +81,7 @@ class ClockDatasetEnv(BaseEnv):
|
|||
self.iter = 0
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
|
||||
try:
|
||||
entry = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
|
@ -128,6 +123,7 @@ class ClockDatasetEnv(BaseEnv):
|
|||
scores["scores"] = []
|
||||
scores["images"] = []
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for item in rollout_group_data:
|
||||
out = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out["tokens"]
|
||||
|
|
@ -169,29 +165,26 @@ class ClockDatasetEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
sys.exit(1)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="clocks",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
wandb_name="pixmo_clocks",
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -103,13 +101,9 @@ class PixmoCountEnv(BaseEnv):
|
|||
history: GameHistory = (user_hist, assistant_hist)
|
||||
to_score.append((history, gold, base64_image))
|
||||
|
||||
return to_score, to_backlog
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
# No custom post-processing
|
||||
pass
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# No custom evaluation
|
||||
|
|
@ -158,29 +152,26 @@ class PixmoCountEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
sys.exit(1)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="pixmo_count",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -50,8 +48,7 @@ class PixmoPointExplanationsEnv(BaseEnv):
|
|||
img.save(buf, format="PNG")
|
||||
img_bytes = buf.getvalue()
|
||||
base64_image = base64.b64encode(img_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Error loading image from URL: {e}")
|
||||
except Exception:
|
||||
base64_image = None
|
||||
|
||||
return (prompt, gold_answer, base64_image)
|
||||
|
|
@ -113,13 +110,9 @@ class PixmoPointExplanationsEnv(BaseEnv):
|
|||
history: GameHistory = (user_hist, assistant_hist)
|
||||
to_score.append((history, gold, base64_image))
|
||||
|
||||
return to_score, to_backlog
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
||||
async def postprocess_histories(
|
||||
self, trajectories: List[GameHistory]
|
||||
) -> ScoredDataGroup:
|
||||
# No custom post-processing needed
|
||||
pass
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# No custom evaluation
|
||||
|
|
@ -169,29 +162,26 @@ class PixmoPointExplanationsEnv(BaseEnv):
|
|||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
sys.exit(1)
|
||||
|
||||
config = BaseEnvConfig(
|
||||
wandb_name="pixmo_point_explanations",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ dependencies = [
|
|||
run-api = "atroposlib.cli.run_api:main"
|
||||
inference-node-wandb-watcher = "atroposlib.cli.inference_node_wandb_watcher:main"
|
||||
view-run = "atroposlib.cli.view_run:main"
|
||||
view-run-multimodal = "atroposlib.cli.view_run_multimodal:main"
|
||||
atropos-sft-gen = "atroposlib.cli.sft:main"
|
||||
atropos-dpo-gen = "atroposlib.cli.dpo:main"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue