mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

Commit 621d00dd80: first commit
89 changed files with 15315 additions and 0 deletions
environments/multimodal_dpo/clevr_cogen_a_train.py (new file, +282 lines)
@@ -0,0 +1,282 @@
import base64
import json
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple

from datasets import load_dataset

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class MultimodalExampleEnv(BaseEnv):
    name = "clevr_cogen_a_train"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        print("DEBUG: Starting collect_trajectories")
        to_score = list()
        to_backlog = list()

        # Default to the base64 image carried with the item so the text-only
        # branch below cannot hit a NameError when appending to to_score
        base64_image = item[2] if len(item) > 2 else None

        # Get the current image if it was stored
        if hasattr(self, "current_image"):
            print("DEBUG: Using current_image for multimodal content")

            # Convert PIL image to base64
            import io

            img_byte_arr = io.BytesIO()
            self.current_image.save(img_byte_arr, format="PNG")
            img_byte_arr = img_byte_arr.getvalue()
            base64_image = base64.b64encode(img_byte_arr).decode("utf-8")

            # Extract text content from item
            user_content = dict(item[0][0]).get("content", "")

            # Try to parse it as JSON if it looks like JSON
            if isinstance(user_content, str) and (
                user_content.startswith("[") or user_content.startswith("{")
            ):
                try:
                    parsed = json.loads(user_content)
                    text_content = ""
                    for element in parsed:
                        if element.get("type") == "text":
                            text_content = element.get("text", "")

                    if not text_content:
                        text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
                except Exception as e:
                    print(f"DEBUG: Error parsing JSON: {e}")
                    text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
            else:
                text_content = user_content

            # Create messages with the multimodal content format
            print("DEBUG: Creating multimodal message with new format")
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_content},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                            },
                        },
                    ],
                },
            ]

        else:
            print("DEBUG: No image available, using text-only message")
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}",
                },
                dict(item[0][0]),
            ]

        print("DEBUG: About to call chat_completion")
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=1024 * 2,
            timeout=60,  # prevent requests from hanging indefinitely
        )
        print("DEBUG: chat_completion call successful")

        for i, chat_completion in enumerate(chat_completions.choices):
            print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
            messages = (
                dict(item[0][0]),
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append((messages, item[1], base64_image))

        print("DEBUG: Finished processing completions")

        print("DEBUG: Returning from collect_trajectories")
        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        pass

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the environment; this is called every steps_per_eval steps.

        :param args:
        :param kwargs:
        :return: None.
        """
        return

    async def setup(self):
        """Set up the environment and load the multimodal dataset."""
        self.dataset = load_dataset("leonardPKU/clevr_cogen_a_train")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        """
        Get the next item to be rolled out, including the image.
        """
        try:
            print("DEBUG: Starting get_next_item")

            # Get next dataset item
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1

            print(f"DEBUG: Retrieved dataset item {self.iter-1}")

            # Use a simple text-only prompt and store the image separately;
            # collect_trajectories handles the multimodal formatting

            # Store the image as an attribute so collect_trajectories can access it
            self.current_image = next_item["image"]
            print("DEBUG: Stored image in current_image attribute")

            # Create a simple text prompt; the image is added in collect_trajectories.
            # This avoids the unhashable-type error from putting lists in frozensets
            text_prompt = next_item["problem"]

            prompt = tuple(
                [frozenset({"role": "user", "content": text_prompt}.items())]
            )
            answer = next_item["solution"]

            # Convert PIL image to base64
            import io

            img_byte_arr = io.BytesIO()
            self.current_image.save(img_byte_arr, format="PNG")
            img_byte_arr = img_byte_arr.getvalue()
            base64_image = base64.b64encode(img_byte_arr).decode("utf-8")

            print("DEBUG: Created simple text-only prompt for get_next_item")
            return (prompt, answer, base64_image)

        except Exception as e:
            print(f"DEBUG: Error in get_next_item: {str(e)}")
            traceback.print_exc()

            # Create a dummy item as a fallback
            prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            answer = "4"
            return (prompt, answer, "obobob")

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["images"] = list()
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Extract the answer from the model's response
            try:
                model_answer = (
                    item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                )
                print(
                    f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
                )

                pattern = r"<answer>\s*(\d{1,2})\s*</answer>"
                string = rollout_group_data[0][1]
                gold_answer = re.search(pattern, string).group(1)

                reward = gold_answer == model_answer
            except (IndexError, AttributeError):
                # AttributeError covers re.search() returning None when the
                # gold string has no <answer> tags
                reward = False

            # Skip obviously bad examples with almost no unmasked tokens
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)

            try:
                scores["images"].append(item[2])
            except IndexError:
                scores["images"].append(None)
            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            print("Please set it using: export OPENAI_API_KEY=your_api_key")
            sys.exit(1)

        print(
            f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
        )

        config = BaseEnvConfig(
            wandb_name="clevr_cogen",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        print("DEBUG: Creating OpenAI configuration")
        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",  # GPT-4o has multimodal capabilities
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    MultimodalExampleEnv.cli()
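A note on the prompt encoding above: prompts are stored as tuples of frozensets so that each item stays hashable. A minimal round-trip sketch, with an illustrative question rather than a real dataset value:

# Dicts are unhashable, so the chat message is frozen via dict.items();
# dict() over the frozenset recovers the ordinary message.
prompt = tuple([frozenset({"role": "user", "content": "How many cubes?"}.items())])
message = dict(prompt[0])
assert message == {"role": "user", "content": "How many cubes?"}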
environments/multimodal_dpo/clevr_complex.py (new file, +283 lines)
@@ -0,0 +1,283 @@
import base64
import json
import os
import random
import sys
import traceback
from typing import List, Optional, Tuple

from datasets import load_dataset

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class MultimodalComplexEnv(BaseEnv):
    name = "clevr_complex"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        print("DEBUG: Starting collect_trajectories")
        to_score = list()
        to_backlog = list()

        # Default to the base64 image carried with the item so the text-only
        # branch below cannot hit a NameError when appending to to_score
        base64_image = item[2] if len(item) > 2 else None

        # Get the current image if it was stored
        if hasattr(self, "current_image"):
            print("DEBUG: Using current_image for multimodal content")

            # Convert PIL image to base64
            import io

            img_byte_arr = io.BytesIO()
            self.current_image.save(img_byte_arr, format="PNG")
            img_byte_arr = img_byte_arr.getvalue()
            base64_image = base64.b64encode(img_byte_arr).decode("utf-8")

            # Extract text content from item
            user_content = dict(item[0][0]).get("content", "")

            # Try to parse it as JSON if it looks like JSON
            if isinstance(user_content, str) and (
                user_content.startswith("[") or user_content.startswith("{")
            ):
                try:
                    parsed = json.loads(user_content)
                    text_content = ""
                    for element in parsed:
                        if element.get("type") == "text":
                            text_content = element.get("text", "")

                    if not text_content:
                        text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
                except Exception as e:
                    print(f"DEBUG: Error parsing JSON: {e}")
                    text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
            else:
                text_content = user_content

            # Create messages with the multimodal content format
            print("DEBUG: Creating multimodal message with new format")
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_content},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                            },
                        },
                    ],
                },
            ]

        else:
            print("DEBUG: No image available, using text-only message")
            messages = [
                {
                    "role": "system",
                    "content": "You must submit your answer with \\boxed{answer}",
                },
                dict(item[0][0]),
            ]

        print("DEBUG: About to call chat_completion")
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=1024 * 2,
            timeout=60,  # prevent requests from hanging indefinitely
        )
        print("DEBUG: chat_completion call successful")

        for i, chat_completion in enumerate(chat_completions.choices):
            print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
            messages = (
                dict(item[0][0]),
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append((messages, item[1], base64_image))

        print("DEBUG: Finished processing completions")

        print("DEBUG: Returning from collect_trajectories")
        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        pass

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the environment; this is called every steps_per_eval steps.

        :param args:
        :param kwargs:
        :return: None.
        """
        return

    async def setup(self):
        """Set up the environment and load the multimodal dataset."""
        self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        """
        Get the next item to be rolled out, including the image.
        """
        try:
            print("DEBUG: Starting get_next_item")

            # Get next dataset item
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1

            print(f"DEBUG: Retrieved dataset item {self.iter-1}")

            # Use a simple text-only prompt and store the image separately;
            # collect_trajectories handles the multimodal formatting

            # Store the image as an attribute so collect_trajectories can access it
            self.current_image = next_item["image"]
            print("DEBUG: Stored image in current_image attribute")

            # Create a simple text prompt; the image is added in collect_trajectories.
            # This avoids the unhashable-type error from putting lists in frozensets
            text_prompt = next_item["problem"]

            prompt = tuple(
                [frozenset({"role": "user", "content": text_prompt}.items())]
            )
            answer = next_item["solution"]

            # Convert PIL image to base64
            import io

            img_byte_arr = io.BytesIO()
            self.current_image.save(img_byte_arr, format="PNG")
            img_byte_arr = img_byte_arr.getvalue()
            base64_image = base64.b64encode(img_byte_arr).decode("utf-8")

            print("DEBUG: Created simple text-only prompt for get_next_item")
            return (prompt, answer, base64_image)

        except Exception as e:
            print(f"DEBUG: Error in get_next_item: {str(e)}")
            traceback.print_exc()

            # Create a dummy item as a fallback
            prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            answer = "4"
            return (prompt, answer, "obobob")

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["images"] = list()
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Extract the answer from the model's response
            try:
                model_answer = (
                    item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
                )
                print(
                    f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
                )

                # Handle both numeric and yes/no answers
                gold_answer = rollout_group_data[0][1]

                # Case-insensitive comparison for yes/no, direct comparison for numbers
                if gold_answer.lower() in ["yes", "no"]:
                    reward = gold_answer.lower() == model_answer.lower()
                else:
                    # For numeric answers
                    reward = gold_answer == model_answer

            except IndexError:
                reward = False

            # Skip obviously bad examples with almost no unmasked tokens
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)

            try:
                scores["images"].append(item[2])
            except IndexError:
                scores["images"].append(None)
            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            print("Please set it using: export OPENAI_API_KEY=your_api_key")
            sys.exit(1)

        print(
            f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
        )

        config = BaseEnvConfig(
            wandb_name="clevr_complex",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        print("DEBUG: Creating OpenAI configuration")
        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",  # GPT-4o has multimodal capabilities
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    MultimodalComplexEnv.cli()
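Both CLEVR environments' score() methods pull the model answer out of \boxed{...} with plain string splitting rather than a regex. A sketch on a made-up completion:

# The text after the last \boxed{ and before the next } is the answer.
content = r"There are three cubes, so the answer is \boxed{3}."
model_answer = content.split("\\boxed{")[-1].split("}")[0]
assert model_answer == "3"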
environments/multimodal_dpo/ocr_vqa.py (new file, +200 lines)
@@ -0,0 +1,200 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple

from datasets import load_dataset

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class OcrVqaEnv(BaseEnv):
    name = "ocr_vqa"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
        to_backlog: List[Item] = []

        # Extract question and image from item
        prompt_tuple, gold, base64_image = item
        # The prompt tuple contains the user prompt as a frozenset
        text_prompt = dict(prompt_tuple[0])["content"]

        # System instruction for answer formatting
        system_msg = {
            "role": "system",
            "content": (
                "You must submit your answer enclosed in <answer> tags, "
                "e.g., <answer>YOUR_ANSWER</answer>"
            ),
        }
        # Multimodal user message with text and image
        user_msg = {
            "role": "user",
            "content": [
                {"type": "text", "text": text_prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                },
            ],
        }
        messages = [system_msg, user_msg]

        # Call the chat completion endpoint
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=512,
            timeout=60,
        )

        # Build trajectories for scoring
        for choice in chat_completions.choices:
            user_hist = {"role": "user", "content": text_prompt}
            assistant_hist = {"role": "assistant", "content": choice.message.content}
            history: GameHistory = (user_hist, assistant_hist)
            to_score.append((history, gold, base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        # No additional post-processing needed
        pass

    async def evaluate(self, *args, **kwargs):
        # No custom evaluation
        return

    async def setup(self):
        # Load the OCR-VQA dataset
        self.dataset = load_dataset("howard-hou/OCR-VQA")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        try:
            entry = self.train[self.iter % len(self.train)]
            self.iter += 1

            # Take the first question and answer
            question = entry["questions"][0]
            answer = entry["answers"][0]
            text_prompt = question
            prompt = tuple(
                [frozenset({"role": "user", "content": text_prompt}.items())]
            )

            # Format the gold answer
            gold_answer = f"<answer>{answer}</answer>"

            # Convert image to base64
            img = entry["image"]
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            img_bytes = buf.getvalue()
            base64_image = base64.b64encode(img_bytes).decode("utf-8")

            return (prompt, gold_answer, base64_image)
        except Exception:
            traceback.print_exc()
            # Fallback example
            fallback_prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            return (fallback_prompt, "<answer>4</answer>", None)

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["images"] = []
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            # Extract model and gold answers
            try:
                reply = item[0][-1]["content"]
                m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
                model_answer = m.group(1).strip() if m else reply.strip()

                gold = item[1]
                g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
                gold_answer = g.group(1).strip() if g else gold.strip()

                reward = model_answer.lower() == gold_answer.lower()
            except Exception:
                reward = False

            # Filter out short examples
            if len([i for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(item[2])
            except Exception:
                scores["images"].append(None)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            sys.exit(1)

        config = BaseEnvConfig(
            wandb_name="ocr_vqa",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    OcrVqaEnv.cli()
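For reference, the <answer> extraction in score() behaves as below on a made-up reply (the book title is illustrative, not a dataset value):

import re

reply = "The title on the cover is <answer>The Pragmatic Programmer</answer>."
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
assert model_answer == "The Pragmatic Programmer"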
environments/multimodal_dpo/pixmo_clocks.py (new file, +202 lines)
@@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple

from datasets import load_dataset

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class ClockDatasetEnv(BaseEnv):
    name = "pixmo_clocks"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
        to_backlog: List[Item] = []

        # Extract the base64 image
        base64_image = item[2]

        # Build system instruction and multimodal user message
        system_msg = {
            "role": "system",
            "content": (
                "You must submit your answer enclosed in <answer> tags, "
                "e.g., <answer>HH:MM</answer>"
            ),
        }
        user_prompt_text = "What time does the clock show?"
        user_msg_multimodal = {
            "role": "user",
            "content": [
                {"type": "text", "text": user_prompt_text},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                },
            ],
        }

        messages = [system_msg, user_msg_multimodal]

        # Call chat completion
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=512,
            timeout=60,
        )

        # Prepare trajectories for scoring
        for choice in chat_completions.choices:
            # Use the text-only prompt for the history
            user_msg = {"role": "user", "content": user_prompt_text}
            assistant_msg = {"role": "assistant", "content": choice.message.content}
            history: GameHistory = (user_msg, assistant_msg)
            to_score.append((history, item[1], base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        # No custom post-processing
        pass

    async def evaluate(self, *args, **kwargs):
        # No custom evaluation
        return

    async def setup(self):
        # Load the clock dataset
        self.dataset = load_dataset("junyeong-nero/clock-dataset")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        try:
            entry = self.train[self.iter % len(self.train)]
            self.iter += 1

            # Match the question text actually sent in collect_trajectories
            text_prompt = "What time does the clock show?"
            prompt = tuple(
                [frozenset({"role": "user", "content": text_prompt}.items())]
            )

            # Format gold answer
            hour = entry["hour"]
            minute = entry["minute"]
            gold_answer = f"<answer>{hour}:{minute:02d}</answer>"

            # Convert image to base64
            img = entry["image"]
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            image_bytes = buf.getvalue()
            base64_image = base64.b64encode(image_bytes).decode("utf-8")

            return (prompt, gold_answer, base64_image)
        except Exception:
            traceback.print_exc()
            # Fallback
            fallback_prompt = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            return (fallback_prompt, "<answer>4</answer>", None)

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["images"] = []
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            # Extract answers
            try:
                reply = item[0][-1]["content"]
                m_match = re.search(
                    r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE
                )
                model_answer = m_match.group(1).strip() if m_match else reply.strip()

                gold = item[1]
                g_match = re.search(
                    r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE
                )
                gold_answer = g_match.group(1).strip() if g_match else gold.strip()

                reward = model_answer == gold_answer
            except Exception:
                reward = False

            if len([i for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(item[2])
            except Exception:
                scores["images"].append(None)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            sys.exit(1)

        config = BaseEnvConfig(
            wandb_name="clocks",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    ClockDatasetEnv.cli()
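The gold label zero-pads minutes but not hours, so a 9:05 clock serializes as follows (hour and minute values are illustrative, not a real dataset entry):

hour, minute = 9, 5  # illustrative values
gold_answer = f"<answer>{hour}:{minute:02d}</answer>"
assert gold_answer == "<answer>9:05</answer>"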
environments/multimodal_dpo/pixmo_count.py (new file, +191 lines)
@@ -0,0 +1,191 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple

import requests
from datasets import load_dataset
from PIL import Image

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class PixmoCountEnv(BaseEnv):
    name = "pixmo_count"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def setup(self):
        # Load the pixmo-count dataset
        self.dataset = load_dataset("allenai/pixmo-count")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        try:
            entry = self.train[self.iter % len(self.train)]
            self.iter += 1

            label = entry["label"]
            count = entry["count"]
            question = f"How many {label} are in the image?"
            prompt = tuple([frozenset({"role": "user", "content": question}.items())])

            gold_answer = f"<answer>{count}</answer>"

            # Load image from URL and convert to base64
            image_url = entry["image_url"]
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content))
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            img_bytes = buf.getvalue()
            base64_image = base64.b64encode(img_bytes).decode("utf-8")

            return (prompt, gold_answer, base64_image)
        except Exception:
            traceback.print_exc()
            fallback = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            return (fallback, "<answer>4</answer>", None)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
        to_backlog: List[Item] = []

        prompt_tuple, gold, base64_image = item
        text_prompt = dict(prompt_tuple[0])["content"]

        system_msg = {
            "role": "system",
            "content": (
                "You must submit your answer enclosed in <answer> tags, "
                "e.g., <answer>3</answer>"
            ),
        }
        user_msg = {
            "role": "user",
            "content": [
                {"type": "text", "text": text_prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                },
            ],
        }
        messages = [system_msg, user_msg]

        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=512,
            timeout=60,
        )

        for choice in chat_completions.choices:
            user_hist = {"role": "user", "content": text_prompt}
            assistant_hist = {"role": "assistant", "content": choice.message.content}
            history: GameHistory = (user_hist, assistant_hist)
            to_score.append((history, gold, base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        # No custom post-processing
        pass

    async def evaluate(self, *args, **kwargs):
        # No custom evaluation
        return

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["images"] = []
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            try:
                reply = item[0][-1]["content"]
                m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
                model_answer = m.group(1).strip() if m else reply.strip()

                gold = item[1]
                g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
                gold_answer = g.group(1).strip() if g else gold.strip()

                reward = model_answer.lower() == gold_answer.lower()
            except Exception:
                reward = False

            if len([i for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(item[2])
            except Exception:
                scores["images"].append(None)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            sys.exit(1)

        config = BaseEnvConfig(
            wandb_name="pixmo_count",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    PixmoCountEnv.cli()
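A self-contained sketch of the URL-to-data-URL pipeline in get_next_item; the URL below is a placeholder, not a dataset value:

import base64
import io

import requests
from PIL import Image

# Download an image, re-encode it as PNG, and wrap it in a data URL.
response = requests.get("https://example.com/sample.png", timeout=10)  # placeholder URL
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
data_url = f"data:image/png;base64,{b64}"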
environments/multimodal_dpo/pixmo_point_explanations.py (new file, +202 lines)
@@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple

import requests
from datasets import load_dataset
from PIL import Image

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class PixmoPointExplanationsEnv(BaseEnv):
    name = "pixmo_point_explanations"
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def setup(self):
        # Load the pixmo-point-explanations dataset
        self.dataset = load_dataset("allenai/pixmo-point-explanations")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        try:
            entry = self.train[self.iter % len(self.train)]
            self.iter += 1

            question = entry["question"]
            # Use the first inline text as the answer
            answer_text = entry["inline_text"][0]
            prompt = tuple([frozenset({"role": "user", "content": question}.items())])
            gold_answer = f"<answer>{answer_text}</answer>"

            # Load image from URL and convert to base64
            try:
                image_url = entry["image_url"]
                response = requests.get(image_url, timeout=10)
                response.raise_for_status()
                img = Image.open(io.BytesIO(response.content))
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                img_bytes = buf.getvalue()
                base64_image = base64.b64encode(img_bytes).decode("utf-8")
            except Exception as e:
                print(f"Error loading image from URL: {e}")
                base64_image = None

            return (prompt, gold_answer, base64_image)
        except Exception:
            traceback.print_exc()
            fallback = tuple(
                [
                    frozenset(
                        {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                    )
                ]
            )
            return (fallback, "<answer>4</answer>", None)

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[GameHistory | None, List[Item]]:
        to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
        to_backlog: List[Item] = []

        prompt_tuple, gold, base64_image = item
        text_prompt = dict(prompt_tuple[0])["content"]

        system_msg = {
            "role": "system",
            "content": (
                "You must submit your answer enclosed in <answer> tags, "
                "e.g., <answer>YOUR_ANSWER</answer>"
            ),
        }
        user_msg = {
            "role": "user",
            "content": [
                {"type": "text", "text": text_prompt},
            ],
        }

        # Only add the image if we have a valid base64 image
        if base64_image:
            user_msg["content"].append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                }
            )
        messages = [system_msg, user_msg]

        # Call chat completion
        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=512,
            timeout=60,
        )

        for choice in chat_completions.choices:
            user_hist = {"role": "user", "content": text_prompt}
            assistant_hist = {"role": "assistant", "content": choice.message.content}
            history: GameHistory = (user_hist, assistant_hist)
            to_score.append((history, gold, base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> ScoredDataGroup:
        # No custom post-processing needed
        pass

    async def evaluate(self, *args, **kwargs):
        # No custom evaluation
        return

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["images"] = []
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            try:
                reply = item[0][-1]["content"]
                m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
                model_answer = m.group(1).strip() if m else reply.strip()

                gold = item[1]
                g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
                gold_answer = g.group(1).strip() if g else gold.strip()

                reward = model_answer.lower() == gold_answer.lower()
            except Exception:
                reward = False

            # Filter out short examples
            if len([i for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(item[2])
            except Exception:
                scores["images"].append(None)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            sys.exit(1)

        config = BaseEnvConfig(
            wandb_name="pixmo_point_explanations",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs


if __name__ == "__main__":
    PixmoPointExplanationsEnv.cli()
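Unlike its siblings, this environment degrades to a text-only message when the image download fails. A minimal sketch of that branch, with illustrative message text:

user_msg = {"role": "user", "content": [{"type": "text", "text": "Describe the scene."}]}
base64_image = None  # e.g. the download failed in get_next_item
if base64_image:
    user_msg["content"].append(
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
    )
assert len(user_msg["content"]) == 1  # stays text-only on failure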