first commit

This commit is contained in:
Dakota Nous 2025-04-29 12:10:10 -07:00
commit 621d00dd80
89 changed files with 15315 additions and 0 deletions

View file

@ -0,0 +1,282 @@
import base64
import json
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class MultimodalExampleEnv(BaseEnv):
name = "clevr_cogen_a_train"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
# Extract text content from item
user_content = dict(item[0][0]).get("content", "")
# Try to parse if it's JSON
if isinstance(user_content, str) and (
user_content.startswith("[") or user_content.startswith("{")
):
try:
parsed = json.loads(user_content)
text_content = ""
for element in parsed:
if element.get("type") == "text":
text_content = element.get("text", "")
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
},
{
"role": "user",
"content": [
{"type": "text", "text": text_content},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
},
},
],
},
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}",
},
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
pass
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment, this is called every steps_per_eval steps
:param args:
:param kwargs:
:return: None.
"""
return
async def setup(self):
"""Setup the environment and load the multimodal dataset"""
self.dataset = load_dataset("leonardPKU/clevr_cogen_a_train")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
"""
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
text_prompt = next_item["problem"]
# Create a simple text-only prompt
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
answer = next_item["solution"]
# get image as base64
# image = next_item["image"]
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
traceback.print_exc()
# Create a dummy item as fallback
prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
answer = "4"
return (prompt, answer, "obobob")
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["images"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Extract the answer from the model's response
try:
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
pattern = r"<answer>\s*(\d{1,2})\s*</answer>"
string = rollout_group_data[0][1]
gold_answer = re.search(pattern, string).group(1)
reward = gold_answer == model_answer
except IndexError:
reward = False
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except IndexError:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_cogen",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
MultimodalExampleEnv.cli()

View file

@ -0,0 +1,283 @@
import base64
import json
import os
import random
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class MultimodalComplexEnv(BaseEnv):
name = "clevr_complex"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
# Extract text content from item
user_content = dict(item[0][0]).get("content", "")
# Try to parse if it's JSON
if isinstance(user_content, str) and (
user_content.startswith("[") or user_content.startswith("{")
):
try:
parsed = json.loads(user_content)
text_content = ""
for element in parsed:
if element.get("type") == "text":
text_content = element.get("text", "")
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
},
{
"role": "user",
"content": [
{"type": "text", "text": text_content},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
},
},
],
},
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}",
},
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable)
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
pass
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment, this is called every steps_per_eval steps
:param args:
:param kwargs:
:return: None.
"""
return
async def setup(self):
"""Setup the environment and load the multimodal dataset"""
self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
"""
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
text_prompt = next_item["problem"]
# Create a simple text-only prompt
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
answer = next_item["solution"]
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
traceback.print_exc()
# Create a dummy item as fallback
prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
answer = "4"
return (prompt, answer, "obobob")
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["images"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Extract the answer from the model's response
try:
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
# Handle both numeric and yes/no answers
gold_answer = rollout_group_data[0][1]
# Case-insensitive comparison for yes/no and direct comparison for numbers
if gold_answer.lower() in ["yes", "no"]:
reward = gold_answer.lower() == model_answer.lower()
else:
# For numeric answers
reward = gold_answer == model_answer
except IndexError:
reward = False
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except IndexError:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_complex",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
MultimodalComplexEnv.cli()

View file

@ -0,0 +1,200 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class OcrVqaEnv(BaseEnv):
name = "ocr_vqa"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
# Extract question and image from item
prompt_tuple, gold, base64_image = item
# The prompt_tuple contains the user prompt as a frozenset
text_prompt = dict(prompt_tuple[0])["content"]
# System instruction for answer formatting
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>YOUR_ANSWER</answer>"
),
}
# Multimodal user message with text and image
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg]
# Call the chat completion endpoint
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
# Build trajectories for scoring
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No additional post-processing needed
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def setup(self):
# Load the OCR-VQA dataset
self.dataset = load_dataset("howard-hou/OCR-VQA")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
# Take the first question and answer
question = entry["questions"][0]
answer = entry["answers"][0]
text_prompt = question
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
# Format the gold answer
gold_answer = f"<answer>{answer}</answer>"
# Convert image to base64
img = entry["image"]
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
# Fallback example
fallback_prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback_prompt, "<answer>4</answer>", None)
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
# Extract model and gold answers
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
# Filter out short examples
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="ocr_vqa",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
OcrVqaEnv.cli()

View file

@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class ClockDatasetEnv(BaseEnv):
name = "pixmo_clocks"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
# Extract the base64 image
base64_image = item[2]
# Build system instruction and multimodal user message
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>HH:MM</answer>"
),
}
user_prompt_text = "What time does the clock show?"
user_msg_multimodal = {
"role": "user",
"content": [
{"type": "text", "text": user_prompt_text},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg_multimodal]
# Call chat completion
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
# Prepare trajectories for scoring
for choice in chat_completions.choices:
# Use text-only prompt for history
user_msg = {"role": "user", "content": user_prompt_text}
assistant_msg = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_msg, assistant_msg)
to_score.append((history, item[1], base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def setup(self):
# Load the clock dataset
self.dataset = load_dataset("junyeong-nero/clock-dataset")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
text_prompt = "What time does the clock show"
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
# Format gold answer
hour = entry["hour"]
minute = entry["minute"]
gold_answer = f"<answer>{hour}:{minute:02d}</answer>"
# Convert image to base64
img = entry["image"]
buf = io.BytesIO()
img.save(buf, format="PNG")
image_bytes = buf.getvalue()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
# Fallback
fallback_prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback_prompt, "<answer>4</answer>", None)
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
# Extract answers
try:
reply = item[0][-1]["content"]
m_match = re.search(
r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE
)
model_answer = m_match.group(1).strip() if m_match else reply.strip()
gold = item[1]
g_match = re.search(
r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE
)
gold_answer = g_match.group(1).strip() if g_match else gold.strip()
reward = model_answer == gold_answer
except Exception:
reward = False
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="clocks",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
ClockDatasetEnv.cli()

View file

@ -0,0 +1,191 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
import requests
from datasets import load_dataset
from PIL import Image
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PixmoCountEnv(BaseEnv):
name = "pixmo_count"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup(self):
# Load the pixmo-count dataset
self.dataset = load_dataset("allenai/pixmo-count")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
label = entry["label"]
count = entry["count"]
question = f"how many {label} are in the image?"
prompt = tuple([frozenset({"role": "user", "content": question}.items())])
gold_answer = f"<answer>{count}</answer>"
# Load image from URL and convert to base64
image_url = entry["image_url"]
response = requests.get(image_url)
img = Image.open(io.BytesIO(response.content))
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
fallback = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback, "<answer>4</answer>", None)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
prompt_tuple, gold, base64_image = item
text_prompt = dict(prompt_tuple[0])["content"]
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>3</answer>"
),
}
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg]
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_count",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
PixmoCountEnv.cli()

View file

@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
import requests
from datasets import load_dataset
from PIL import Image
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PixmoPointExplanationsEnv(BaseEnv):
name = "pixmo_point_explanations"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup(self):
# Load the pixmo-point-explanations dataset
self.dataset = load_dataset("allenai/pixmo-point-explanations")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
question = entry["question"]
# Use the first inline text as the answer
answer_text = entry["inline_text"][0]
prompt = tuple([frozenset({"role": "user", "content": question}.items())])
gold_answer = f"<answer>{answer_text}</answer>"
# Load image from URL and convert to base64
try:
image_url = entry["image_url"]
response = requests.get(image_url, timeout=10)
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
print(f"Error loading image from URL: {e}")
base64_image = None
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
fallback = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback, "<answer>4</answer>", None)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
prompt_tuple, gold, base64_image = item
text_prompt = dict(prompt_tuple[0])["content"]
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>YOUR_ANSWER</answer>"
),
}
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
],
}
# Only add image if we have a valid base64 image
if base64_image:
user_msg["content"].append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
}
)
messages = [system_msg, user_msg]
# Call chat completion
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No custom post-processing needed
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
# Filter out short examples
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_point_explanations",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
PixmoPointExplanationsEnv.cli()