mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
init commit
This commit is contained in:
parent
c189fc3351
commit
0e660a7429
19 changed files with 11250 additions and 0 deletions
326
environments/hack0/ufc_env/ufc_image_env.py
Normal file
326
environments/hack0/ufc_env/ufc_image_env.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
import csv
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class UFCImageEnvConfig(BaseEnvConfig):
    """Settings for the image-only UFC fight prediction environment.

    Extends ``BaseEnvConfig`` with the dataset locations and the sampling
    parameters used when requesting completions. The three path fields
    default to entries located in the same directory as this module.
    """

    # CSV of per-fighter statistics.
    fighter_stats_path: str = Field(
        default=os.path.join(os.path.dirname(__file__), "fighter_stats.csv"),
        description="Path to fighter stats CSV",
    )
    # CSV with one row per fight.
    fight_data_path: str = Field(
        default=os.path.join(os.path.dirname(__file__), "large_dataset.csv"),
        description="Path to large fight dataset CSV",
    )
    # Directory that holds the fighter portrait images.
    image_folder: str = Field(
        default=os.path.join(os.path.dirname(__file__), "fighter_images"),
        description="Path to fighter images folder",
    )
    # Each fight is a single-turn prediction.
    max_steps: int = Field(
        default=1,
        description="Only one step per fight prediction",
    )
    # Sampling parameters forwarded to the chat-completion request.
    temperature: float = Field(
        default=0.7,
        description="Temperature for generation diversity",
    )
    top_p: float = Field(
        default=0.95,
        description="Top p for nucleus sampling",
    )
|
||||
|
||||
|
||||
class UFCImageEnv(BaseEnv):
    """UFC fight prediction environment that shows the model only fighter images.

    Each item is one fight: the model receives the red-corner and blue-corner
    fighter portraits, produces a commentary, and must finish with
    ``\\boxed{Red}`` or ``\\boxed{Blue}``. A completion is rewarded ``1.0`` when
    the boxed colour equals the ``winner`` column of the fight row and ``-1.0``
    otherwise (including when no boxed answer is present).
    """

    name = "ufc_image_predictor"
    env_config_cls = UFCImageEnvConfig

    def __init__(
        self,
        config: UFCImageEnvConfig,
        server_configs: List[OpenaiConfig],
        slurm=True,
        testing=False,
    ):
        """Create the environment; data is loaded later by :meth:`setup`.

        Args:
            config: Environment configuration (paths, sampling parameters).
            server_configs: Inference server configurations passed to the base class.
            slurm: Forwarded unchanged to ``BaseEnv.__init__``.
            testing: Forwarded unchanged to ``BaseEnv.__init__``.
        """
        super().__init__(config, server_configs, slurm, testing)
        self.fighter_stats = {}  # fighter name -> CSV row, filled by setup()
        self.fight_data = []  # fight rows that have both images, filled by setup()
        self.current_index = 0  # position of the next fight to serve
        # All completions are requested from the first configured server.
        self.inference_server = self.server.servers[0]

    def _image_path(self, fighter_name):
        """Return the expected image path for ``fighter_name``.

        The file name is the lower-cased fighter name with spaces replaced by
        hyphens and a ``.jpg`` extension, inside ``config.image_folder``.
        """
        slug = fighter_name.lower().replace(" ", "-")
        return os.path.join(self.config.image_folder, f"{slug}.jpg")

    async def setup(self):
        """Load fighter stats and fight rows, keeping only fights with both images.

        Populates ``self.fighter_stats`` (keyed by the ``name`` column) and
        ``self.fight_data``. Any error while loading is printed with its
        traceback and terminates the process with exit status 1.
        """
        try:
            print("Loading fighter stats from:", self.config.fighter_stats_path)
            # newline="" is what the csv module asks for so that newlines
            # embedded in quoted fields are parsed correctly.
            with open(
                self.config.fighter_stats_path, encoding="utf-8", newline=""
            ) as f:
                self.fighter_stats = {row["name"]: row for row in csv.DictReader(f)}
            print(f"Loaded stats for {len(self.fighter_stats)} fighters")

            print("Loading fight data from:", self.config.fight_data_path)
            with open(
                self.config.fight_data_path, encoding="utf-8", newline=""
            ) as f:
                self.fight_data = list(csv.DictReader(f))
            print(f"Loaded {len(self.fight_data)} fights")

            # Filter out fights where either fighter's image is missing.
            filtered_fights = []
            missing_images = set()  # unique fighter names without an image
            for fight in self.fight_data:
                absent = [
                    fighter
                    for fighter in (fight["r_fighter"], fight["b_fighter"])
                    if not os.path.exists(self._image_path(fighter))
                ]
                if absent:
                    missing_images.update(absent)
                else:
                    filtered_fights.append(fight)

            if missing_images:
                print(
                    f"\nMissing images for {len(missing_images)} fighters. These fights will be skipped."
                )

            self.fight_data = filtered_fights
            print(f"Filtered to {len(self.fight_data)} fights with complete image sets")

        except Exception as e:
            print(f"Error loading data: {e}")
            traceback.print_exc()
            sys.exit(1)

    def get_fighter_image(self, fighter_name):
        """Return the fighter's image as a base64-encoded JPEG string.

        Args:
            fighter_name: Fighter name as it appears in the fight data.

        Returns:
            The base64 text of the JPEG-encoded image, or ``None`` when the
            image file does not exist or cannot be read/encoded (the error is
            printed in that case).
        """
        try:
            image_path = self._image_path(fighter_name)
            if not os.path.exists(image_path):
                return None

            with Image.open(image_path) as img:
                # Normalise every mode other than plain RGB / greyscale to RGB.
                # Modes such as RGBA, P or LA cannot be written as JPEG, and
                # previously only RGBA was handled, so other modes failed.
                if img.mode not in ("RGB", "L"):
                    img = img.convert("RGB")
                buf = io.BytesIO()
                img.save(buf, format="JPEG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception as e:
            print(f"Error getting image for {fighter_name}: {e}")
            return None

    @staticmethod
    def _build_prompt(r_image, b_image):
        """Build the one-message multimodal user prompt for a fight.

        Args:
            r_image: Base64 JPEG text for the red-corner fighter.
            b_image: Base64 JPEG text for the blue-corner fighter.

        Returns:
            A 1-tuple holding a ``user`` message whose content is the text
            instructions followed by the two images as ``data:`` URLs.
        """
        prompt_text = (
            "🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment "
            "Let's break down this matchup that's got everyone talking!\n\n"
            "In the red corner, we have:(YOUR FIRST IMAGE):\n"
            "And in the blue corner: (YOUR SECOND IMAGE):\n\n"
            "Now, act as your favorite fight comentator, I want you to:\n"
            "create a fight commentary of whats happening in the fight live\n"
            "Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
            "Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
            "End your masterpiece with your prediction in this exact format:\n"
            "[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs)\n\n"
            # The next three literals used to be joined with no separator,
            # which glued the words onto the \boxed{} tokens.
            "The winner should always be annouced with "
            "\\boxed{Red} or \\boxed{Blue}\n"
            "Or you will receive a score of -1.0"
        )
        return tuple(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{r_image}"},
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b_image}"},
                        },
                    ],
                }
            ]
        )

    async def get_next_item(self) -> Optional[Item]:
        """Return the next fight as ``(prompt, ground_truth, None)``.

        Fights whose images cannot be loaded are skipped and the following
        fight is tried, so ``None`` is only returned when the dataset is
        exhausted or an unexpected error occurs (which is printed).

        ``ground_truth`` is ``"Answer: <winner>"`` using the row's ``winner``
        column, or an empty string when that column is missing or empty.
        """
        try:
            while self.current_index < len(self.fight_data):
                fight = self.fight_data[self.current_index]
                self.current_index += 1

                # Get base64 encoded images for both corners.
                r_image = self.get_fighter_image(fight["r_fighter"])
                b_image = self.get_fighter_image(fight["b_fighter"])
                if not r_image or not b_image:
                    print(f"Skipping fight {self.current_index} due to missing images")
                    continue

                winner = fight.get("winner", "")  # Red or Blue
                ground_truth = f"Answer: {winner}" if winner else ""
                return (self._build_prompt(r_image, b_image), ground_truth, None)

            return None

        except Exception as e:
            print(f"Error in get_next_item: {e}")
            traceback.print_exc()
            return None

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Sample ``group_size`` completions for ``item`` and score them.

        Args:
            item: ``(prompt, ground_truth, None)`` as produced by
                :meth:`get_next_item`.

        Returns:
            ``(scored_group, backlog)``. ``scored_group`` is the result of
            :meth:`score`, or ``None`` when no completion was obtained.
            ``backlog`` contains ``item`` when the request raised, and is
            empty otherwise.
        """
        to_score = []
        to_backlog = []

        system_msg = {
            "role": "system",
            "content": (
                "You are an expert MMA analyst. You will be given two UFC fighters' images. "
                "Your task is to predict the winner of the fight based on their appearance and physique.\n\n"
                "IMPORTANT: You MUST format your response in exactly two parts:\n"
                "1. First, analyze the fighters' appearances and create a fight commentary\n"
                "2. Then on a new line, give ONLY your final prediction in this exact format:\n"
                "\\boxed{Red} or \\boxed{Blue}\n\n"
                "For example:\n"
                "After analyzing the fighters' appearances... [your analysis here]\n"
                "\\boxed{Red}\n\n"
                "If you do not end your response with the \\boxed{} format containing either 'Red' or 'Blue', you will receive a score of -1.0."
            ),
        }
        user_msg = {"role": "user", "content": dict(item[0][0])["content"]}
        messages = [system_msg, user_msg]

        try:
            chat_completions = await self.inference_server.chat_completion(
                messages=messages,
                n=self.config.group_size,
                max_tokens=2048,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                timeout=60,
            )
            for choice in chat_completions.choices:
                # A completion's content may be None; treat that as an empty
                # reply so tokenisation and scoring receive a string.
                reply = choice.message.content or ""
                history = [
                    dict(system_msg),
                    dict(user_msg),
                    {"role": "assistant", "content": reply},
                ]
                to_score.append((history, item[1], None))
        except Exception as e:
            print(f"Error in collect_trajectories: {e}")
            traceback.print_exc()
            to_backlog.append(item)

        if not to_score:
            return None, to_backlog

        scored_data = await self.score(to_score)
        return scored_data, to_backlog

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Tokenise and reward a group of rollouts.

        Args:
            rollout_group_data: List of ``(history, ground_truth, _)`` tuples.
                The list is shuffled in place.

        Returns:
            A ``ScoredDataGroup`` with ``tokens``, ``masks``, ``scores`` and
            ``ground_truths`` for at most ``group_size`` rollouts, or ``None``
            when the input is empty or every rollout was discarded.
        """
        if not rollout_group_data:
            return None

        # Imported locally (as the original code did) and compiled once,
        # instead of re-importing inside the per-rollout loop.
        import re

        boxed_pattern = re.compile(r"\\boxed{([^}]+)}")

        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["advantages"] = None
        scores["ref_logprobs"] = None
        scores["messages"] = None
        scores["group_overrides"] = {"group_size": self.config.group_size}
        scores["overrides"] = None
        scores["ground_truths"] = []

        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            try:
                reply = item[0][-1]["content"]
                ground_truth = item[1].strip().lower()
                # Ground truth has the form "answer: <colour>"; keep the colour.
                ground_truth_color = ground_truth.replace("answer:", "").strip()

                # Use the LAST \boxed{...}: the prompts themselves contain
                # "\boxed{Red} or \boxed{Blue}", so a reply that echoes the
                # template before giving its final answer would otherwise be
                # scored on the echoed text.
                boxed_answers = boxed_pattern.findall(reply)
                if boxed_answers:
                    prediction = boxed_answers[-1].strip().lower()
                    reward = 1.0 if prediction == ground_truth_color else -1.0
                else:
                    # No boxed answer found.
                    reward = -1.0
            except Exception as e:
                print(f"Error scoring response: {e}")
                reward = -1.0
                ground_truth = item[1] if isinstance(item[1], str) else ""

            # Discard rollouts with fewer than 10 unmasked (trainable) tokens.
            if sum(1 for m in masks if m != -100) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)
            scores["ground_truths"].append(ground_truth)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        if not scores["tokens"]:
            return None

        return scores

    async def evaluate(self, *args, **kwargs):
        """No-op evaluation"""
        return

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        """Build the default environment and server configuration.

        Exits the process with status 1 when ``OPENAI_API_KEY`` is not set.

        Returns:
            ``(config, server_configs)`` with a single ``gpt-4o`` server entry.
        """
        if not os.environ.get("OPENAI_API_KEY"):
            print("ERROR: OPENAI_API_KEY environment variable is not set!")
            sys.exit(1)

        config = UFCImageEnvConfig(
            wandb_name="ufc_image",
            tokenizer_name="gpt2",
            group_size=2,
            use_wandb=False,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1,
            steps_per_eval=10,
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            OpenaiConfig(
                model_name="gpt-4o",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=1,
            ),
        ]

        return config, server_configs
|
||||
|
||||
if __name__ == "__main__":
    # Run the environment's command-line entry point when executed as a script.
    # `cli` is not defined in this file; presumably it is inherited from BaseEnv.
    UFCImageEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue