mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
202 lines
6.3 KiB
Python
202 lines
6.3 KiB
Python
import base64
|
|
import io
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
import traceback
|
|
from typing import List, Optional, Tuple
|
|
|
|
from datasets import load_dataset
|
|
|
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
|
from atroposlib.type_definitions import GameHistory, Item
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
|
|
# Prompt sent to the model together with every clock image.
_CLOCK_PROMPT = "What time does the clock show?"

# Captures the text wrapped in <answer>...</answer> tags (case-insensitive).
_ANSWER_TAG_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.IGNORECASE)

# Recognises a bare clock reading such as "3:05" or "03:05".
_CLOCK_TIME_RE = re.compile(r"^(\d{1,2}):(\d{2})$")


class ClockDatasetEnv(BaseEnv):
    """Environment that asks a model to read the time shown on a clock image.

    Items handled by this class are ``(prompt, gold_answer, base64_image)``
    tuples as built by :meth:`get_next_item`. ``base64_image`` is ``None``
    for the text-only fallback item.
    """

    name = "pixmo_clocks"
    # NOTE(review): confirm that BaseEnv actually reads an attribute with
    # this exact name; the value is kept unchanged here.
    name_config_cls = BaseEnvConfig

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to the base environment."""
        super().__init__(*args, **kwargs)

    @staticmethod
    def _prompt_text(item: Item) -> str:
        """Return the user prompt stored in *item*, or the clock prompt."""
        try:
            # The prompt is a tuple of frozensets of (key, value) pairs.
            return dict(item[0][0])["content"]
        except (TypeError, ValueError, KeyError, IndexError):
            return _CLOCK_PROMPT

    @staticmethod
    def _canonical_answer(text: Optional[str]) -> Optional[str]:
        """Extract and normalise the answer contained in *text*.

        The content of the first ``<answer>`` tag is used when present,
        otherwise the whole stripped text. Clock readings are rewritten as
        ``H:MM`` so that ``03:05`` and ``3:05`` compare equal. Returns
        ``None`` when *text* is not a string.
        """
        if not isinstance(text, str):
            return None
        tagged = _ANSWER_TAG_RE.search(text)
        answer = tagged.group(1).strip() if tagged else text.strip()
        clock = _CLOCK_TIME_RE.match(answer)
        if clock is None:
            return answer
        return f"{int(clock.group(1))}:{clock.group(2)}"

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
        """Request ``group_size`` completions for *item*.

        Returns:
            A pair ``(to_score, to_backlog)``. ``to_score`` holds one
            ``(history, gold_answer, base64_image)`` tuple per completion,
            where ``history`` is a text-only ``(user, assistant)`` message
            pair. ``to_backlog`` is always empty.
        """
        to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
        to_backlog: List[Item] = []

        gold_answer = item[1]
        base64_image = item[2]

        system_msg = {
            "role": "system",
            "content": (
                "You must submit your answer enclosed in <answer> tags, "
                "e.g., <answer>HH:MM</answer>"
            ),
        }

        if base64_image is None:
            # Fallback item: there is no image, so ask the item's own
            # question as plain text instead of sending a broken data URL.
            user_prompt_text = self._prompt_text(item)
            user_content = user_prompt_text
        else:
            user_prompt_text = _CLOCK_PROMPT
            user_content = [
                {"type": "text", "text": user_prompt_text},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                },
            ]

        messages = [system_msg, {"role": "user", "content": user_content}]

        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=512,
            timeout=60,
        )

        for choice in chat_completions.choices:
            # The stored history keeps only the text of the prompt; the image
            # travels separately as the third tuple element.
            user_msg = {"role": "user", "content": user_prompt_text}
            assistant_msg = {"role": "assistant", "content": choice.message.content}
            history: GameHistory = (user_msg, assistant_msg)
            to_score.append((history, gold_answer, base64_image))

        return to_score, to_backlog

    async def postprocess_histories(
        self, trajectories: List[GameHistory]
    ) -> Optional[ScoredDataGroup]:
        """Perform no post-processing; always returns ``None``.

        NOTE(review): nothing in this class calls :meth:`score`, and this
        hook discards its input. Confirm how the base class turns the output
        of :meth:`collect_trajectories` into scored data.
        """
        return None

    async def evaluate(self, *args, **kwargs):
        """Perform no evaluation; always returns ``None``."""
        return None

    async def setup(self):
        """Load the clock dataset and reset the iteration cursor."""
        self.dataset = load_dataset("junyeong-nero/clock-dataset")
        self.train = self.dataset["train"]
        self.iter = 0

    async def get_next_item(self) -> Item:
        """Return the next ``(prompt, gold_answer, base64_image)`` item.

        Entries are served round-robin from the training split. If building
        the item fails for any reason the traceback is printed and a
        text-only arithmetic item (image ``None``) is returned instead.
        """
        try:
            # Advance the cursor before touching the entry so an entry that
            # fails to load is skipped next time instead of retried forever.
            position = self.iter
            self.iter += 1
            entry = self.train[position % len(self.train)]

            text_prompt = "What time does the clock show"
            prompt = (frozenset({"role": "user", "content": text_prompt}.items()),)

            hour = entry["hour"]
            minute = entry["minute"]
            gold_answer = f"<answer>{hour}:{minute:02d}</answer>"

            # Re-encode the image as PNG and base64 for the data URL.
            buf = io.BytesIO()
            entry["image"].save(buf, format="PNG")
            base64_image = base64.b64encode(buf.getvalue()).decode("utf-8")

            return (prompt, gold_answer, base64_image)
        except Exception:
            # Deliberate best-effort: report the problem and keep serving.
            traceback.print_exc()
            fallback_prompt = (
                frozenset(
                    {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
                ),
            )
            return (fallback_prompt, "<answer>4</answer>", None)

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Tokenize and reward a group of rollouts.

        Args:
            rollout_group_data: sequence of ``(history, gold_answer, image)``
                tuples as produced by :meth:`collect_trajectories`.

        Returns:
            A ``ScoredDataGroup`` with ``tokens``, ``masks``, ``scores`` and
            ``images`` lists. A rollout scores ``1.0`` when its answer
            matches the gold answer and ``-1.0`` otherwise. Rollouts with
            fewer than 10 mask entries different from ``-100`` are skipped,
            and at most ``group_size`` rollouts are kept.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["images"] = []

        # Shuffle a copy so the caller's list keeps its order.
        rollouts = list(rollout_group_data)
        random.shuffle(rollouts)

        for rollout in rollouts:
            tokenized = tokenize_for_trainer(self.tokenizer, rollout[0])
            tokens = tokenized["tokens"]
            masks = tokenized["masks"]

            try:
                model_answer = self._canonical_answer(rollout[0][-1]["content"])
                gold_answer = self._canonical_answer(rollout[1])
                reward = model_answer is not None and model_answer == gold_answer
            except Exception:
                # A malformed rollout counts as a wrong answer.
                reward = False

            if sum(1 for m in masks if m != -100) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            try:
                scores["images"].append(rollout[2])
            except Exception:
                scores["images"].append(None)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        return scores

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        """Build the default environment and server configuration.

        Exits the process with status 1 when ``OPENAI_API_KEY`` is unset or
        empty.
        """
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            print(
                "ERROR: OPENAI_API_KEY environment variable is not set!",
                file=sys.stderr,
            )
            sys.exit(1)

        config = BaseEnvConfig(
            wandb_name="clocks",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=api_key,
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the environment's CLI runner,
    # which this class does not define itself and so takes from its base.
    ClockDatasetEnv.cli()