init commit

This commit is contained in:
edmundman 2025-05-18 16:58:42 -07:00
parent c189fc3351
commit 0e660a7429
19 changed files with 11250 additions and 0 deletions

View file

@ -0,0 +1,127 @@
# UFC Fight Prediction Environment
This environment provides a framework for training and evaluating AI models on UFC fight prediction tasks, with a unique twist: instead of traditional analytical predictions, it generates entertaining fight commentary that can be directly used with Text-to-Speech (TTS) models like DIA. The environment includes two main components: a text-based predictor and an image-based predictor, both designed to create engaging, broadcast-style fight commentary.
## Environment Design
### Core Components
1. **UFC Server (ufc_server.py)**
- Text-based fight prediction environment
- Generates dynamic, entertaining fight commentary
- Uses fighter statistics and historical data
- Outputs TTS-ready commentary with dramatic flair
- Implements a scoring system for model evaluation
2. **UFC Image Environment (ufc_image_env.py)**
- Visual-based fight prediction environment
- Creates commentary based on fighter appearances
- Implements multimodal prediction capabilities
- Generates broadcast-style commentary from visual analysis
- Includes image processing and base64 encoding utilities
### Data Structure
- **fighter_stats.csv**: Contains detailed statistics for each fighter including:
- Win/Loss records
- Physical attributes (height, weight, reach)
- Performance metrics (strikes per minute, takedown accuracy, etc.)
- **large_dataset.csv**: Historical fight data including:
- Fighter matchups
- Fight outcomes
- Event information
- **fighter_images/**: Directory containing fighter profile images
- Images are stored in JPG format
- Filenames follow slug format (e.g., "john-smith.jpg")
## Motivation
This environment was designed to transform traditional fight prediction into an engaging entertainment experience:
1. **Entertainment-First Approach**
- Generates dynamic, broadcast-style fight commentary
- Creates TTS-ready output for voice synthesis
- Incorporates dramatic elements and commentator personalities
- Makes fight prediction more engaging and accessible
2. **Statistical Analysis with Style**
- Wraps technical analysis in entertaining commentary
- Uses fight statistics to inform dramatic storytelling
- Maintains prediction accuracy while being entertaining
- Creates a more engaging way to present fight analysis
3. **Visual Storytelling**
- Transforms visual analysis into engaging commentary
- Creates dramatic narratives from fighter appearances
- Makes technical observations more accessible
- Generates TTS-compatible descriptions of visual elements
4. **Multimodal Entertainment**
- Combines statistical and visual data for rich commentary
- Creates cohesive narratives from multiple data sources
- Generates engaging stories that work well with TTS
- Makes technical analysis more accessible and fun
## Usage
1. Install dependencies:
```bash
pip install -r requirements.txt
```
2. Prepare data:
- Ensure fighter_stats.csv and large_dataset.csv are in the environment directory
- Place fighter images in the fighter_images/ directory
3. Run the environment:
- For text-based commentary: Use UFCEnv
- For image-based commentary: Use UFCImageEnv
4. TTS Integration:
- The generated commentary is formatted for direct use with TTS models
- Includes dramatic pauses and emphasis markers
- Contains natural speech patterns and commentator personalities
- Ready for voice synthesis with models like DIA
## Example Runs
Here are some example runs demonstrating the environment in action:
- [Video Demo](https://youtu.be/C_hFe6TfQvU) - Watch the environment in action with real-time commentary generation
- [Text-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/rq5wfxgh?nw=nwuseredtheman) - Shows the environment generating commentary based on fighter statistics and historical data
- [Image-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/klw4m5of?nw=nwuseredtheman) - Demonstrates the environment creating commentary from visual analysis of fighter appearances
The key difference between these runs is their input modality:
- The text-based run focuses on statistical analysis and historical data to generate commentary
- The image-based run analyzes fighter appearances and visual characteristics to create engaging narratives
## Configuration
The environment can be configured through the following parameters:
- `fighter_stats_path`: Path to fighter statistics CSV
- `fight_data_path`: Path to fight dataset CSV
- `image_folder`: Path to fighter images directory
- `max_steps`: Number of steps per prediction
- `temperature`: Generation diversity parameter (affects commentary style)
- `top_p`: Nucleus sampling parameter (affects commentary creativity)
## Scoring System
The environment implements a scoring system that evaluates predictions based on:
- Accuracy of winner prediction
- Entertainment value of the commentary
- TTS compatibility and natural flow
- Integration of statistical/visual data in an engaging way
- Proper formatting for voice synthesis
## Contributing
Contributions are welcome! Please feel free to submit pull requests for:
- New commentary styles and personalities
- Enhanced TTS compatibility features
- Additional dramatic elements
- Improved entertainment value
- Better integration with voice synthesis models

View file

@ -0,0 +1,3 @@
from .ufc_server import UFCEnv, UFCEnvConfig
__all__ = ["UFCEnv", "UFCEnvConfig"]

Binary file not shown.

After

Width:  |  Height:  |  Size: 344 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 167 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 414 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 353 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 360 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 330 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 411 KiB

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,113 @@
import os
import re
import pandas as pd
import requests
from bs4 import BeautifulSoup
import time
import random
# 1. Load your CSV
df = pd.read_csv("fighter_stats.csv")
# List of different user agents to rotate between
USER_AGENTS = [
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Edg/91.0.864.59',
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
]
# List of different accept languages
ACCEPT_LANGUAGES = [
'en-US,en;q=0.9',
'en-GB,en;q=0.9',
'en-CA,en;q=0.9',
'en-AU,en;q=0.9',
'en-NZ,en;q=0.9',
]
def get_random_headers():
"""Generate a random set of headers for each request."""
return {
'User-Agent': random.choice(USER_AGENTS),
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
'Accept-Language': random.choice(ACCEPT_LANGUAGES),
'Connection': 'keep-alive',
'Upgrade-Insecure-Requests': '1',
'Cache-Control': 'max-age=0',
'Sec-Fetch-Dest': 'document',
'Sec-Fetch-Mode': 'navigate',
'Sec-Fetch-Site': 'none',
'Sec-Fetch-User': '?1',
'DNT': '1',
}
# 2. Generate UFC profile URLs
def make_url(name):
"""Turn a fighter name into a UFC athlete URL slug."""
if not isinstance(name, str) or not name.strip():
return None
slug = name.lower().strip()
slug = re.sub(r"[''']", "", slug) # Remove apostrophes
slug = re.sub(r"[^a-z\s-]", "", slug) # Keep letters, spaces, hyphens
slug = re.sub(r"\s+", "-", slug) # Spaces → hyphens
return f"https://www.ufc.com/athlete/{slug}"
df["UFC_Profile_URL"] = df["name"].apply(make_url)
# 3. Prepare output folder
output_folder = "fighter_images"
os.makedirs(output_folder, exist_ok=True)
# 4. Scrape & download each image
def download_ufc_image(url, counter):
"""Fetch the UFC athlete page, parse out the main profile image, and save it."""
try:
# Derive filename from URL slug
slug = url.rstrip("/").split("/")[-1]
filename = f"{slug}.jpg"
path = os.path.join(output_folder, filename)
# Check if image already exists
if os.path.exists(path):
print(f"[i] Image already exists for {filename}")
return
# Get fresh random headers for each request
headers = get_random_headers()
resp = requests.get(url, headers=headers)
resp.raise_for_status()
soup = BeautifulSoup(resp.text, "html.parser")
img_tag = soup.select_one("div.hero-profile__image-wrap img.hero-profile__image")
if not img_tag or not img_tag.get("src"):
print(f"[!] No image found at {url}")
return
img_url = img_tag["src"]
# Get fresh random headers for image download
headers = get_random_headers()
img_data = requests.get(img_url, headers=headers).content
with open(path, "wb") as f:
f.write(img_data)
print(f"[✓] Saved {filename}")
# Add random delay between 1-3 seconds
time.sleep(random.uniform(1, 3))
# After every 100 downloads, take a longer break
if counter % 100 == 0:
print(f"[i] Taking a 30-second break after {counter} downloads...")
time.sleep(30)
except Exception as e:
print(f"[✗] Error at {url}: {e}")
# 5. Iterate and download
counter = 0
for url in df["UFC_Profile_URL"].dropna().unique():
counter += 1
download_ufc_image(url, counter)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,8 @@
pydantic>=2.0.0
Pillow>=9.0.0
datasets>=2.0.0
numpy>=1.21.0
pandas>=1.3.0
requests>=2.26.0
aiohttp>=3.8.0
python-dotenv>=0.19.0

View file

@ -0,0 +1,416 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Fight Predictor</title>
<script src="https://cdn.jsdelivr.net/npm/canvas-confetti@1.6.0/dist/confetti.browser.min.js"></script>
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #1a1a1a;
color: #ffffff;
margin: 0;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
h1 {
text-align: center;
margin-bottom: 30px;
}
.fighters-container {
display: flex;
justify-content: space-around;
margin: 20px 0;
position: relative;
}
.fighter {
width: 300px;
text-align: center;
transition: all 0.5s ease;
}
.fighter img {
width: 100%;
height: auto;
border: 3px solid;
border-radius: 10px;
transition: all 0.5s ease;
}
.red-corner img {
border-color: #ff0000;
}
.blue-corner img {
border-color: #0000ff;
}
.winner {
transform: scale(1.2);
z-index: 2;
position: relative;
}
.winner::after {
content: "WINNER!";
position: absolute;
top: -30px;
left: 50%;
transform: translateX(-50%);
color: #ffd700;
font-size: 24px;
font-weight: bold;
text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
}
.loser {
animation: explode 1s forwards;
}
@keyframes explode {
0% {
transform: scale(1);
opacity: 1;
}
50% {
transform: scale(1.5);
opacity: 0.5;
}
100% {
transform: scale(0);
opacity: 0;
}
}
.commentary {
background-color: #2a2a2a;
padding: 20px;
border-radius: 10px;
margin: 20px 0;
min-height: 200px;
white-space: pre-wrap;
display: none;
}
.typing {
overflow: hidden;
border-right: 2px solid #fff;
white-space: nowrap;
animation: typing 1s steps(40, end),
blink-caret 0.75s step-end infinite;
}
@keyframes typing {
from { width: 0 }
to { width: 100% }
}
@keyframes blink-caret {
from, to { border-color: transparent }
50% { border-color: #fff }
}
.upload-form {
display: flex;
justify-content: space-around;
margin: 20px 0;
}
.upload-section {
text-align: center;
}
input[type="file"] {
display: none;
}
.upload-btn {
background-color: #333;
color: white;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
.upload-btn:hover {
background-color: #444;
}
.predict-btn {
background-color: #ff0000;
color: white;
padding: 15px 30px;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 18px;
margin: 20px auto;
display: block;
transition: background-color 0.3s;
}
.predict-btn:hover {
background-color: #cc0000;
}
.loading {
display: none;
text-align: center;
margin: 20px 0;
font-size: 24px;
}
.loading::before {
content: "⚡";
display: inline-block;
animation: loading 0.5s infinite;
margin-right: 10px;
}
.loading::after {
content: "Analyzing Fight...";
display: inline-block;
animation: pulse 1s infinite;
}
@keyframes loading {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
@keyframes pulse {
0% { opacity: 0.5; }
50% { opacity: 1; }
100% { opacity: 0.5; }
}
.new-fight-btn {
background-color: #4CAF50;
color: white;
padding: 15px 30px;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 18px;
margin: 20px auto;
display: none;
transition: background-color 0.3s;
}
.new-fight-btn:hover {
background-color: #45a049;
}
.buttons-container {
display: flex;
justify-content: center;
gap: 20px;
margin: 20px 0;
}
</style>
</head>
<body>
<div class="container">
<h1>🤼‍♂️ Fight Predictor 🤼‍♂️</h1>
<!-- Hidden iframe for airhorn -->
<iframe id="airhorn" width="110" height="200" src="https://www.myinstants.com/instant/dj-airhorn/embed/" frameborder="0" scrolling="no" style="display: none;"></iframe>
<form id="predictForm" class="upload-form">
<div class="upload-section">
<label class="upload-btn">
Upload Red Corner Fighter
<input type="file" name="red_fighter" accept="image/*" required>
</label>
<div class="fighter red-corner">
<img id="redPreview" src="" alt="Red Corner Fighter" style="display: none;">
</div>
</div>
<div class="upload-section">
<label class="upload-btn">
Upload Blue Corner Fighter
<input type="file" name="blue_fighter" accept="image/*" required>
</label>
<div class="fighter blue-corner">
<img id="bluePreview" src="" alt="Blue Corner Fighter" style="display: none;">
</div>
</div>
</form>
<div class="buttons-container">
<button class="predict-btn" onclick="predictFight()">Predict Fight!</button>
<button class="new-fight-btn" onclick="resetFight()">New Fight</button>
</div>
<div class="loading" id="loading"></div>
<div class="commentary" id="commentary"></div>
</div>
<script>
function previewImage(input, previewId) {
const preview = document.getElementById(previewId);
if (input.files && input.files[0]) {
const reader = new FileReader();
reader.onload = function(e) {
preview.src = e.target.result;
preview.style.display = 'block';
}
reader.readAsDataURL(input.files[0]);
}
}
document.querySelector('input[name="red_fighter"]').addEventListener('change', function() {
previewImage(this, 'redPreview');
});
document.querySelector('input[name="blue_fighter"]').addEventListener('change', function() {
previewImage(this, 'bluePreview');
});
function playAirhorn() {
const iframe = document.getElementById('airhorn');
iframe.contentWindow.postMessage('play', '*');
}
function triggerConfetti() {
const duration = 3 * 1000;
const animationEnd = Date.now() + duration;
const defaults = { startVelocity: 30, spread: 360, ticks: 60, zIndex: 0 };
function randomInRange(min, max) {
return Math.random() * (max - min) + min;
}
const interval = setInterval(function() {
const timeLeft = animationEnd - Date.now();
if (timeLeft <= 0) {
return clearInterval(interval);
}
const particleCount = 50 * (timeLeft / duration);
// since particles fall down, start a bit higher than random
confetti({
...defaults,
particleCount,
origin: { x: randomInRange(0.1, 0.3), y: Math.random() - 0.2 }
});
confetti({
...defaults,
particleCount,
origin: { x: randomInRange(0.7, 0.9), y: Math.random() - 0.2 }
});
}, 250);
}
function typeText(element, text, speed = 30) {
let i = 0;
element.innerHTML = '';
element.style.display = 'block';
function type() {
if (i < text.length) {
element.innerHTML += text.charAt(i);
i++;
setTimeout(type, speed);
} else {
// After typing is complete, check for winner
const winnerMatch = text.match(/\\boxed{(Red|Blue)}/);
if (winnerMatch) {
const winner = winnerMatch[1];
const redFighter = document.querySelector('.red-corner');
const blueFighter = document.querySelector('.blue-corner');
if (winner === 'Red') {
redFighter.classList.add('winner');
blueFighter.classList.add('loser');
// Trigger effects from red corner
triggerConfetti();
playAirhorn();
} else {
blueFighter.classList.add('winner');
redFighter.classList.add('loser');
// Trigger effects from blue corner
triggerConfetti();
playAirhorn();
}
// Show new fight button after prediction is complete
document.querySelector('.new-fight-btn').style.display = 'block';
}
}
}
type();
}
function resetFight() {
// Reset images
document.getElementById('redPreview').style.display = 'none';
document.getElementById('redPreview').src = '';
document.getElementById('bluePreview').style.display = 'none';
document.getElementById('bluePreview').src = '';
// Reset form
document.getElementById('predictForm').reset();
// Reset classes
document.querySelector('.red-corner').classList.remove('winner', 'loser');
document.querySelector('.blue-corner').classList.remove('winner', 'loser');
// Clear and hide commentary
const commentary = document.getElementById('commentary');
commentary.innerHTML = '';
commentary.style.display = 'none';
// Hide new fight button
document.querySelector('.new-fight-btn').style.display = 'none';
}
async function predictFight() {
const form = document.getElementById('predictForm');
const formData = new FormData(form);
const loading = document.getElementById('loading');
const commentary = document.getElementById('commentary');
const predictBtn = document.querySelector('.predict-btn');
loading.style.display = 'block';
commentary.style.display = 'none';
predictBtn.disabled = true;
try {
const response = await fetch('/predict', {
method: 'POST',
body: formData
});
const data = await response.json();
if (data.success) {
typeText(commentary, data.prediction);
} else {
commentary.style.display = 'block';
commentary.innerHTML = `Error: ${data.error}`;
}
} catch (error) {
commentary.style.display = 'block';
commentary.innerHTML = `Error: ${error.message}`;
} finally {
loading.style.display = 'none';
predictBtn.disabled = false;
}
}
</script>
</body>
</html>

View file

@ -0,0 +1,326 @@
import os
import random
import sys
import traceback
import csv
from typing import List, Optional, Tuple, Any, Dict
import base64
from PIL import Image
import io
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class UFCImageEnvConfig(BaseEnvConfig):
"""Configuration for the UFC Image Environment"""
fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
image_folder: str = Field(os.path.join(os.path.dirname(__file__), "fighter_images"), description="Path to fighter images folder")
max_steps: int = Field(1, description="Only one step per fight prediction")
temperature: float = Field(0.7, description="Temperature for generation diversity")
top_p: float = Field(0.95, description="Top p for nucleus sampling")
class UFCImageEnv(BaseEnv):
"""UFC Fight Prediction Environment using only fighter images"""
name = "ufc_image_predictor"
env_config_cls = UFCImageEnvConfig
def __init__(self, config: UFCImageEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
super().__init__(config, server_configs, slurm, testing)
self.fighter_stats = {}
self.fight_data = []
self.current_index = 0
self.inference_server = self.server.servers[0] # Get first server as inference server
async def setup(self):
"""Load the fighter stats and fight data"""
try:
print("Loading fighter stats from:", self.config.fighter_stats_path)
with open(self.config.fighter_stats_path, encoding="utf-8") as f:
reader = csv.DictReader(f)
self.fighter_stats = {row["name"]: row for row in reader}
print(f"Loaded stats for {len(self.fighter_stats)} fighters")
print("Loading fight data from:", self.config.fight_data_path)
with open(self.config.fight_data_path, encoding="utf-8") as f:
reader = csv.DictReader(f)
self.fight_data = list(reader)
print(f"Loaded {len(self.fight_data)} fights")
# Filter out fights where either fighter's image is missing
filtered_fights = []
missing_images = set() # Track unique missing images
for fight in self.fight_data:
r_fighter = fight["r_fighter"]
b_fighter = fight["b_fighter"]
# Convert names to image filename format
r_slug = r_fighter.lower().replace(" ", "-")
b_slug = b_fighter.lower().replace(" ", "-")
r_image_path = os.path.join(self.config.image_folder, f"{r_slug}.jpg")
b_image_path = os.path.join(self.config.image_folder, f"{b_slug}.jpg")
if os.path.exists(r_image_path) and os.path.exists(b_image_path):
filtered_fights.append(fight)
else:
if not os.path.exists(r_image_path):
missing_images.add(r_fighter)
if not os.path.exists(b_image_path):
missing_images.add(b_fighter)
if missing_images:
print(f"\nMissing images for {len(missing_images)} fighters. These fights will be skipped.")
self.fight_data = filtered_fights
print(f"Filtered to {len(self.fight_data)} fights with complete image sets")
except Exception as e:
print(f"Error loading data: {e}")
traceback.print_exc()
sys.exit(1)
def get_fighter_image(self, fighter_name):
"""Convert fighter name to image path and return base64 encoded image"""
try:
# Convert name to slug format
slug = fighter_name.lower().replace(" ", "-")
image_path = os.path.join(self.config.image_folder, f"{slug}.jpg")
if not os.path.exists(image_path):
return None
# Convert image to base64
with Image.open(image_path) as img:
# Convert RGBA to RGB if necessary
if img.mode == 'RGBA':
img = img.convert('RGB')
buf = io.BytesIO()
img.save(buf, format="JPEG")
image_bytes = buf.getvalue()
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
print(f"Error getting image for {fighter_name}: {e}")
return None
async def get_next_item(self) -> Optional[Item]:
"""Get the next fight from the dataset"""
try:
if self.current_index >= len(self.fight_data):
return None
fight = self.fight_data[self.current_index]
self.current_index += 1
r_fighter = fight["r_fighter"]
b_fighter = fight["b_fighter"]
# Get base64 encoded images
r_image = self.get_fighter_image(r_fighter)
b_image = self.get_fighter_image(b_fighter)
if not r_image or not b_image:
print(f"Skipping fight {self.current_index} due to missing images")
return None
# Format the prompt with images
prompt_text = (
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment "
"Let's break down this matchup that's got everyone talking!\n\n"
"In the red corner, we have:(YOUR FIRST IMAGE):\n"
"And in the blue corner: (YOUR SECOND IMAGE):\n\n"
"Now, act as your favorite fight comentator, I want you to:\n"
"create a fight commentary of whats happening in the fight live\n"
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
"End your masterpiece with your prediction in this exact format:\n"
"[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs)\n\n"
"The winner should always be annouced with"
"\\boxed{Red} or \\boxed{Blue}"
"Or you will receive a score of -1.0"
)
# Create multimodal prompt with images
prompt = tuple([
{
"role": "user",
"content": [
{"type": "text", "text": prompt_text},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{r_image}"}
},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{b_image}"}
}
]
}
])
winner = fight.get("winner", "") # Red or Blue
ground_truth = f"Answer: {winner}" if winner else ""
return (prompt, ground_truth, None)
except Exception as e:
print(f"Error in get_next_item: {e}")
traceback.print_exc()
return None
async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
to_score = []
to_backlog = []
system_msg = {
"role": "system",
"content": (
"You are an expert MMA analyst. You will be given two UFC fighters' images. "
"Your task is to predict the winner of the fight based on their appearance and physique.\n\n"
"IMPORTANT: You MUST format your response in exactly two parts:\n"
"1. First, analyze the fighters' appearances and create a fight commentary\n"
"2. Then on a new line, give ONLY your final prediction in this exact format:\n"
"\\boxed{Red} or \\boxed{Blue}\n\n"
"For example:\n"
"After analyzing the fighters' appearances... [your analysis here]\n"
"\\boxed{Red}\n\n"
"If you do not end your response with the \\boxed{} format containing either 'Red' or 'Blue', you will receive a score of -1.0."
)
}
user_msg = {
"role": "user",
"content": dict(item[0][0])["content"]
}
messages = [system_msg, user_msg]
try:
chat_completions = await self.inference_server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=2048,
temperature=self.config.temperature,
top_p=self.config.top_p,
timeout=60,
)
for choice in chat_completions.choices:
assistant_msg = {"role": "assistant", "content": choice.message.content}
history = [
{"role": "system", "content": system_msg["content"]},
{"role": "user", "content": user_msg["content"]},
{"role": "assistant", "content": choice.message.content}
]
to_score.append((history, item[1], None))
except Exception as e:
print(f"Error in collect_trajectories: {e}")
traceback.print_exc()
to_backlog.append(item)
if not to_score:
return None, to_backlog
scored_data = await self.score(to_score)
return scored_data, to_backlog
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
if not rollout_group_data:
return None
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["advantages"] = None
scores["ref_logprobs"] = None
scores["messages"] = None
scores["group_overrides"] = {"group_size": self.config.group_size}
scores["overrides"] = None
scores["ground_truths"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
# Extract prediction and ground truth
reply = item[0][-1]["content"]
ground_truth = item[1].strip().lower()
# Extract color from ground truth (format: "answer: color")
ground_truth_color = ground_truth.replace("answer:", "").strip()
# Extract color from \boxed{color} format
import re
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
if boxed_match:
prediction = boxed_match.group(1).strip().lower()
# Compare just the colors
reward = 1.0 if prediction == ground_truth_color else -1.0
else:
# No boxed answer found
reward = -1.0
except Exception as e:
print(f"Error scoring response: {e}")
reward = -1.0
ground_truth = item[1] if isinstance(item[1], str) else ""
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward)
scores["ground_truths"].append(ground_truth)
if len(scores["tokens"]) >= self.config.group_size:
break
if not scores["tokens"]:
return None
return scores
async def evaluate(self, *args, **kwargs):
"""No-op evaluation"""
return
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
"""Initialize configuration for the environment"""
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = UFCImageEnvConfig(
wandb_name="ufc_image",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
UFCImageEnv.cli()

View file

@ -0,0 +1,104 @@
import os
import base64
from io import BytesIO
from flask import Flask, render_template, request, jsonify
from PIL import Image
import openai
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
# Initialize OpenAI client
client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
def process_image(image_file):
"""Convert uploaded image to base64"""
img = Image.open(image_file)
# Convert RGBA to RGB if necessary
if img.mode == 'RGBA':
img = img.convert('RGB')
buffered = BytesIO()
img.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
@app.route('/')
def home():
return render_template('predictor.html')
@app.route('/predict', methods=['POST'])
def predict():
try:
# Get uploaded images
red_fighter = request.files['red_fighter']
blue_fighter = request.files['blue_fighter']
if not red_fighter or not blue_fighter:
return jsonify({'error': 'Please upload both fighter images'}), 400
# Process images to base64
red_image = process_image(red_fighter)
blue_image = process_image(blue_fighter)
# Create the prompt
prompt_text = (
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment "
"Let's break down this matchup that's got everyone talking!\n\n"
"In the red corner, we have:(YOUR FIRST IMAGE):\n"
"And in the blue corner: (YOUR SECOND IMAGE):\n\n"
"Now, as your favorite fight comentator, I want you to:\n"
"create a fight commentary of whats happening in the fight live\n"
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
"End your masterpiece with your prediction in this exact format:\n"
"\\boxed{Red} or \\boxed{Blue}"
"PLEASE FORMAT THE COMMENTARY IN THE EXACT FORMAT AS THE EXAMPLE BELOW:\n"
"[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs) ( add lots of coughs and laughs)\n\n"
)
# Create the messages for the API call
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt_text},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{red_image}"}
},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{blue_image}"}
}
]
}
]
# Make the API call
response = client.chat.completions.create(
model="gpt-4o",
messages=messages,
max_tokens=2048,
temperature=0.7,
top_p=0.95
)
# Extract the prediction
prediction = response.choices[0].message.content
return jsonify({
'prediction': prediction,
'success': True
})
except Exception as e:
return jsonify({
'error': str(e),
'success': False
}), 500
if __name__ == '__main__':
app.run(debug=True)

View file

@ -0,0 +1,233 @@
import os
import random
import sys
import traceback
import csv
from typing import List, Optional, Tuple, Any, Dict
from datasets import load_dataset
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class UFCEnvConfig(BaseEnvConfig):
"""Configuration for the UFC Environment"""
fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
max_steps: int = Field(1, description="Only one step per fight prediction")
temperature: float = Field(0.7, description="Temperature for generation diversity")
top_p: float = Field(0.95, description="Top p for nucleus sampling")
class UFCEnv(BaseEnv):
"""UFC Fight Prediction Environment"""
name = "ufc_predictor"
env_config_cls = UFCEnvConfig
def __init__(self, config: UFCEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
super().__init__(config, server_configs, slurm, testing)
self.fighter_stats = {}
self.fight_data = []
self.current_index = 0
self.inference_server = self.server.servers[0] # Get first server as inference server
async def setup(self):
"""Load the fighter stats and fight data"""
try:
print("Loading fighter stats from:", self.config.fighter_stats_path)
with open(self.config.fighter_stats_path, encoding="utf-8") as f:
reader = csv.DictReader(f)
self.fighter_stats = {row["name"]: row for row in reader}
print(f"Loaded stats for {len(self.fighter_stats)} fighters")
print("Loading fight data from:", self.config.fight_data_path)
with open(self.config.fight_data_path, encoding="utf-8") as f:
reader = csv.DictReader(f)
self.fight_data = list(reader)
print(f"Loaded {len(self.fight_data)} fights")
except Exception as e:
print(f"Error loading data: {e}")
traceback.print_exc()
sys.exit(1)
async def get_next_item(self) -> Optional[Item]:
"""Get the next fight from the dataset"""
try:
if self.current_index >= len(self.fight_data):
return None
fight = self.fight_data[self.current_index]
self.current_index += 1
r_fighter = fight["r_fighter"]
b_fighter = fight["b_fighter"]
r_stats = self.fighter_stats.get(r_fighter, {})
b_stats = self.fighter_stats.get(b_fighter, {})
# Format the prompt
def stats_str(name, stats):
if not stats:
return f"{name}: (No stats available)"
return (
f"Name: {name}\n"
f"Wins: {stats.get('wins','?')} Losses: {stats.get('losses','?')} Age: {stats.get('age','?')}\n"
f"Height: {stats.get('height','?')} cm Weight: {stats.get('weight','?')} kg Reach: {stats.get('reach','?')} cm Stance: {stats.get('stance','?')}\n"
f"SLpM: {stats.get('SLpM','?')} Sig Str Acc: {stats.get('sig_str_acc','?')} SApM: {stats.get('SApM','?')} Str Def: {stats.get('str_def','?')}\n"
f"TD Avg: {stats.get('td_avg','?')} TD Acc: {stats.get('td_acc','?')} TD Def: {stats.get('td_def','?')} Sub Avg: {stats.get('sub_avg','?')}\n"
)
prompt_text = (
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment - the UFC Fight Prediction Show! "
"Let's break down this matchup that's got everyone talking!\n\n"
f"*Drumroll please* In the red corner, we have :\n{stats_str(r_fighter, r_stats)}\n\n"
f"And in the blue corner:\n{stats_str(b_fighter, b_stats)}\n\n"
"Now, as your favorite fight analyst who's definitely not just making this up as I go along, I want you to:\n"
"1. Break down these fighters like you're explaining why your favorite TV show character would win in a fight\n"
"2. Compare their styles\n"
"3. Point out their advantages\n"
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
"End your masterpiece with the winner's name in this exact format:\n"
"\\boxed{fighter name}"
)
prompt = tuple([
frozenset({"role": "user", "content": prompt_text}.items())
])
winner = fight.get("winner", "") # Red or Blue
winner_name = r_fighter if winner == "Red" else b_fighter if winner == "Blue" else ""
ground_truth = f"Answer: {winner_name}" if winner_name else ""
return (prompt, ground_truth, None)
except Exception as e:
print(f"Error in get_next_item: {e}")
traceback.print_exc()
return None
async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
to_score = []
to_backlog = []
system_msg = {
"role": "system",
"content": (
"You are an expert MMA analyst. You will be given two UFC fighters and their stats. "
"Your task is to predict the winner of the fight based on their statistics.\n\n"
"IMPORTANT: You MUST format your response in exactly two parts:\n"
"1. First, analyze the fighters' stats and explain create a fight commentary\n"
"2. Then on a new line, give ONLY your final prediction in this exact format:\n"
"\\boxed{fighter name}\n\n"
"For example:\n"
"After analyzing stats... [your analysis here]\n"
"\\boxed{John Smith}\n\n"
"If you do not end your response with the \\boxed{} format, you will receive a score of -1.0."
)
}
user_msg = {
"role": "user",
"content": dict(item[0][0])["content"]
}
messages = [system_msg, user_msg]
try:
chat_completions = await self.inference_server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=2048, # Increased from 512 to allow for longer, more detailed fight commentaries
temperature=self.config.temperature,
top_p=self.config.top_p,
timeout=60,
)
for choice in chat_completions.choices:
assistant_msg = {"role": "assistant", "content": choice.message.content}
history = [
{"role": "system", "content": system_msg["content"]},
{"role": "user", "content": user_msg["content"]},
{"role": "assistant", "content": choice.message.content}
]
to_score.append((history, item[1], None))
except Exception as e:
print(f"Error in collect_trajectories: {e}")
traceback.print_exc()
to_backlog.append(item)
if not to_score:
return None, to_backlog
scored_data = await self.score(to_score)
return scored_data, to_backlog
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
if not rollout_group_data:
return None
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["advantages"] = None
scores["ref_logprobs"] = None
scores["messages"] = None
scores["group_overrides"] = {"group_size": self.config.group_size}
scores["overrides"] = None
scores["ground_truths"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
# Extract prediction and ground truth
reply = item[0][-1]["content"]
ground_truth = item[1].strip().lower()
# Extract name from ground truth (format: "answer: name")
ground_truth_name = ground_truth.replace("answer:", "").strip()
# Extract name from \boxed{name} format
import re
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
if boxed_match:
prediction = boxed_match.group(1).strip().lower()
# Compare just the names
reward = 1.0 if prediction == ground_truth_name else -1.0
else:
# No boxed answer found
reward = -1.0
except Exception as e:
print(f"Error scoring response: {e}")
reward = -1.0
ground_truth = item[1] if isinstance(item[1], str) else ""
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward)
scores["ground_truths"].append(ground_truth)
if len(scores["tokens"]) >= self.config.group_size:
break
if not scores["tokens"]:
return None
return scores
async def evaluate(self, *args, **kwargs):
"""No-op evaluation"""
return
if __name__ == "__main__":
UFCEnv.cli()