init commit
127
environments/hack0/ufc_env/README.md
Normal file
|
|
@ -0,0 +1,127 @@
|
||||||
|
# UFC Fight Prediction Environment
|
||||||
|
|
||||||
|
This environment provides a framework for training and evaluating AI models on UFC fight prediction tasks, with a unique twist: instead of traditional analytical predictions, it generates entertaining fight commentary that can be directly used with Text-to-Speech (TTS) models like DIA. The environment includes two main components: a text-based predictor and an image-based predictor, both designed to create engaging, broadcast-style fight commentary.
|
||||||
|
|
||||||
|
## Environment Design
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
1. **UFC Server (ufc_server.py)**
|
||||||
|
- Text-based fight prediction environment
|
||||||
|
- Generates dynamic, entertaining fight commentary
|
||||||
|
- Uses fighter statistics and historical data
|
||||||
|
- Outputs TTS-ready commentary with dramatic flair
|
||||||
|
- Implements a scoring system for model evaluation
|
||||||
|
|
||||||
|
2. **UFC Image Environment (ufc_image_env.py)**
|
||||||
|
- Visual-based fight prediction environment
|
||||||
|
- Creates commentary based on fighter appearances
|
||||||
|
- Implements multimodal prediction capabilities
|
||||||
|
- Generates broadcast-style commentary from visual analysis
|
||||||
|
- Includes image processing and base64 encoding utilities
|
||||||
|
|
||||||
|
### Data Structure
|
||||||
|
|
||||||
|
- **fighter_stats.csv**: Contains detailed statistics for each fighter including:
|
||||||
|
- Win/Loss records
|
||||||
|
- Physical attributes (height, weight, reach)
|
||||||
|
- Performance metrics (strikes per minute, takedown accuracy, etc.)
|
||||||
|
|
||||||
|
- **large_dataset.csv**: Historical fight data including:
|
||||||
|
- Fighter matchups
|
||||||
|
- Fight outcomes
|
||||||
|
- Event information
|
||||||
|
|
||||||
|
- **fighter_images/**: Directory containing fighter profile images
|
||||||
|
- Images are stored in JPG format
|
||||||
|
- Filenames follow slug format (e.g., "john-smith.jpg")
|
||||||
|
|
||||||
|
## Motivation
|
||||||
|
|
||||||
|
This environment was designed to transform traditional fight prediction into an engaging entertainment experience:
|
||||||
|
|
||||||
|
1. **Entertainment-First Approach**
|
||||||
|
- Generates dynamic, broadcast-style fight commentary
|
||||||
|
- Creates TTS-ready output for voice synthesis
|
||||||
|
- Incorporates dramatic elements and commentator personalities
|
||||||
|
- Makes fight prediction more engaging and accessible
|
||||||
|
|
||||||
|
2. **Statistical Analysis with Style**
|
||||||
|
- Wraps technical analysis in entertaining commentary
|
||||||
|
- Uses fight statistics to inform dramatic storytelling
|
||||||
|
- Maintains prediction accuracy while being entertaining
|
||||||
|
- Creates a more engaging way to present fight analysis
|
||||||
|
|
||||||
|
3. **Visual Storytelling**
|
||||||
|
- Transforms visual analysis into engaging commentary
|
||||||
|
- Creates dramatic narratives from fighter appearances
|
||||||
|
- Makes technical observations more accessible
|
||||||
|
- Generates TTS-compatible descriptions of visual elements
|
||||||
|
|
||||||
|
4. **Multimodal Entertainment**
|
||||||
|
- Combines statistical and visual data for rich commentary
|
||||||
|
- Creates cohesive narratives from multiple data sources
|
||||||
|
- Generates engaging stories that work well with TTS
|
||||||
|
- Makes technical analysis more accessible and fun
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
1. Install dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Prepare data:
|
||||||
|
- Ensure fighter_stats.csv and large_dataset.csv are in the environment directory
|
||||||
|
- Place fighter images in the fighter_images/ directory
|
||||||
|
|
||||||
|
3. Run the environment:
|
||||||
|
- For text-based commentary: Use UFCEnv
|
||||||
|
- For image-based commentary: Use UFCImageEnv
|
||||||
|
|
||||||
|
4. TTS Integration:
|
||||||
|
- The generated commentary is formatted for direct use with TTS models
|
||||||
|
- Includes dramatic pauses and emphasis markers
|
||||||
|
- Contains natural speech patterns and commentator personalities
|
||||||
|
- Ready for voice synthesis with models like DIA
|
||||||
|
|
||||||
|
## Example Runs
|
||||||
|
|
||||||
|
Here are some example runs demonstrating the environment in action:
|
||||||
|
|
||||||
|
- [Video Demo](https://youtu.be/C_hFe6TfQvU) - Watch the environment in action with real-time commentary generation
|
||||||
|
- [Text-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/rq5wfxgh?nw=nwuseredtheman) - Shows the environment generating commentary based on fighter statistics and historical data
|
||||||
|
- [Image-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/klw4m5of?nw=nwuseredtheman) - Demonstrates the environment creating commentary from visual analysis of fighter appearances
|
||||||
|
|
||||||
|
The key difference between these runs is their input modality:
|
||||||
|
- The text-based run focuses on statistical analysis and historical data to generate commentary
|
||||||
|
- The image-based run analyzes fighter appearances and visual characteristics to create engaging narratives
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The environment can be configured through the following parameters:
|
||||||
|
|
||||||
|
- `fighter_stats_path`: Path to fighter statistics CSV
|
||||||
|
- `fight_data_path`: Path to fight dataset CSV
|
||||||
|
- `image_folder`: Path to fighter images directory
|
||||||
|
- `max_steps`: Number of steps per prediction
|
||||||
|
- `temperature`: Generation diversity parameter (affects commentary style)
|
||||||
|
- `top_p`: Nucleus sampling parameter (affects commentary creativity)
|
||||||
|
|
||||||
|
## Scoring System
|
||||||
|
|
||||||
|
The environment implements a scoring system that evaluates predictions based on:
|
||||||
|
- Accuracy of winner prediction
|
||||||
|
- Entertainment value of the commentary
|
||||||
|
- TTS compatibility and natural flow
|
||||||
|
- Integration of statistical/visual data in an engaging way
|
||||||
|
- Proper formatting for voice synthesis
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit pull requests for:
|
||||||
|
- New commentary styles and personalities
|
||||||
|
- Enhanced TTS compatibility features
|
||||||
|
- Additional dramatic elements
|
||||||
|
- Improved entertainment value
|
||||||
|
- Better integration with voice synthesis models
|
||||||
3
environments/hack0/ufc_env/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .ufc_server import UFCEnv, UFCEnvConfig
|
||||||
|
|
||||||
|
__all__ = ["UFCEnv", "UFCEnvConfig"]
|
||||||
BIN
environments/hack0/ufc_env/fighter_images/matt-frevola.jpg
Normal file
|
After Width: | Height: | Size: 344 KiB |
BIN
environments/hack0/ufc_env/fighter_images/matt-hamill.jpg
Normal file
|
After Width: | Height: | Size: 167 KiB |
BIN
environments/hack0/ufc_env/fighter_images/randy-brown.jpg
Normal file
|
After Width: | Height: | Size: 414 KiB |
BIN
environments/hack0/ufc_env/fighter_images/tom-nolan.jpg
Normal file
|
After Width: | Height: | Size: 353 KiB |
BIN
environments/hack0/ufc_env/fighter_images/vagner-rocha.jpg
Normal file
|
After Width: | Height: | Size: 117 KiB |
BIN
environments/hack0/ufc_env/fighter_images/yanis-ghemmouri.jpg
Normal file
|
After Width: | Height: | Size: 360 KiB |
BIN
environments/hack0/ufc_env/fighter_images/zhalgas-zhumagulov.jpg
Normal file
|
After Width: | Height: | Size: 330 KiB |
BIN
environments/hack0/ufc_env/fighter_images/zviad-lazishvili.jpg
Normal file
|
After Width: | Height: | Size: 411 KiB |
2480
environments/hack0/ufc_env/fighter_stats.csv
Normal file
113
environments/hack0/ufc_env/get_images.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import pandas as pd
|
||||||
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
|
# 1. Load your CSV
|
||||||
|
df = pd.read_csv("fighter_stats.csv")
|
||||||
|
|
||||||
|
# List of different user agents to rotate between
|
||||||
|
USER_AGENTS = [
|
||||||
|
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||||
|
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36',
|
||||||
|
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||||
|
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||||
|
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
|
||||||
|
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Edg/91.0.864.59',
|
||||||
|
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||||
|
]
|
||||||
|
|
||||||
|
# List of different accept languages
|
||||||
|
ACCEPT_LANGUAGES = [
|
||||||
|
'en-US,en;q=0.9',
|
||||||
|
'en-GB,en;q=0.9',
|
||||||
|
'en-CA,en;q=0.9',
|
||||||
|
'en-AU,en;q=0.9',
|
||||||
|
'en-NZ,en;q=0.9',
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_random_headers():
|
||||||
|
"""Generate a random set of headers for each request."""
|
||||||
|
return {
|
||||||
|
'User-Agent': random.choice(USER_AGENTS),
|
||||||
|
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
|
||||||
|
'Accept-Language': random.choice(ACCEPT_LANGUAGES),
|
||||||
|
'Connection': 'keep-alive',
|
||||||
|
'Upgrade-Insecure-Requests': '1',
|
||||||
|
'Cache-Control': 'max-age=0',
|
||||||
|
'Sec-Fetch-Dest': 'document',
|
||||||
|
'Sec-Fetch-Mode': 'navigate',
|
||||||
|
'Sec-Fetch-Site': 'none',
|
||||||
|
'Sec-Fetch-User': '?1',
|
||||||
|
'DNT': '1',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Generate UFC profile URLs
|
||||||
|
def make_url(name):
|
||||||
|
"""Turn a fighter name into a UFC athlete URL slug."""
|
||||||
|
if not isinstance(name, str) or not name.strip():
|
||||||
|
return None
|
||||||
|
slug = name.lower().strip()
|
||||||
|
slug = re.sub(r"[''']", "", slug) # Remove apostrophes
|
||||||
|
slug = re.sub(r"[^a-z\s-]", "", slug) # Keep letters, spaces, hyphens
|
||||||
|
slug = re.sub(r"\s+", "-", slug) # Spaces → hyphens
|
||||||
|
return f"https://www.ufc.com/athlete/{slug}"
|
||||||
|
|
||||||
|
df["UFC_Profile_URL"] = df["name"].apply(make_url)
|
||||||
|
|
||||||
|
# 3. Prepare output folder
|
||||||
|
output_folder = "fighter_images"
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
|
||||||
|
# 4. Scrape & download each image
|
||||||
|
def download_ufc_image(url, counter):
|
||||||
|
"""Fetch the UFC athlete page, parse out the main profile image, and save it."""
|
||||||
|
try:
|
||||||
|
# Derive filename from URL slug
|
||||||
|
slug = url.rstrip("/").split("/")[-1]
|
||||||
|
filename = f"{slug}.jpg"
|
||||||
|
path = os.path.join(output_folder, filename)
|
||||||
|
|
||||||
|
# Check if image already exists
|
||||||
|
if os.path.exists(path):
|
||||||
|
print(f"[i] Image already exists for {filename}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get fresh random headers for each request
|
||||||
|
headers = get_random_headers()
|
||||||
|
|
||||||
|
resp = requests.get(url, headers=headers)
|
||||||
|
resp.raise_for_status()
|
||||||
|
soup = BeautifulSoup(resp.text, "html.parser")
|
||||||
|
img_tag = soup.select_one("div.hero-profile__image-wrap img.hero-profile__image")
|
||||||
|
if not img_tag or not img_tag.get("src"):
|
||||||
|
print(f"[!] No image found at {url}")
|
||||||
|
return
|
||||||
|
img_url = img_tag["src"]
|
||||||
|
|
||||||
|
# Get fresh random headers for image download
|
||||||
|
headers = get_random_headers()
|
||||||
|
img_data = requests.get(img_url, headers=headers).content
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(img_data)
|
||||||
|
print(f"[✓] Saved {filename}")
|
||||||
|
|
||||||
|
# Add random delay between 1-3 seconds
|
||||||
|
time.sleep(random.uniform(1, 3))
|
||||||
|
|
||||||
|
# After every 100 downloads, take a longer break
|
||||||
|
if counter % 100 == 0:
|
||||||
|
print(f"[i] Taking a 30-second break after {counter} downloads...")
|
||||||
|
time.sleep(30)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[✗] Error at {url}: {e}")
|
||||||
|
|
||||||
|
# 5. Iterate and download
|
||||||
|
counter = 0
|
||||||
|
for url in df["UFC_Profile_URL"].dropna().unique():
|
||||||
|
counter += 1
|
||||||
|
download_ufc_image(url, counter)
|
||||||
7440
environments/hack0/ufc_env/large_dataset.csv
Normal file
8
environments/hack0/ufc_env/requirements.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
pydantic>=2.0.0
|
||||||
|
Pillow>=9.0.0
|
||||||
|
datasets>=2.0.0
|
||||||
|
numpy>=1.21.0
|
||||||
|
pandas>=1.3.0
|
||||||
|
requests>=2.26.0
|
||||||
|
aiohttp>=3.8.0
|
||||||
|
python-dotenv>=0.19.0
|
||||||
416
environments/hack0/ufc_env/templates/predictor.html
Normal file
|
|
@ -0,0 +1,416 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Fight Predictor</title>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/canvas-confetti@1.6.0/dist/confetti.browser.min.js"></script>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
background-color: #1a1a1a;
|
||||||
|
color: #ffffff;
|
||||||
|
margin: 0;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
max-width: 1200px;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fighters-container {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-around;
|
||||||
|
margin: 20px 0;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fighter {
|
||||||
|
width: 300px;
|
||||||
|
text-align: center;
|
||||||
|
transition: all 0.5s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fighter img {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
border: 3px solid;
|
||||||
|
border-radius: 10px;
|
||||||
|
transition: all 0.5s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.red-corner img {
|
||||||
|
border-color: #ff0000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.blue-corner img {
|
||||||
|
border-color: #0000ff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.winner {
|
||||||
|
transform: scale(1.2);
|
||||||
|
z-index: 2;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.winner::after {
|
||||||
|
content: "WINNER!";
|
||||||
|
position: absolute;
|
||||||
|
top: -30px;
|
||||||
|
left: 50%;
|
||||||
|
transform: translateX(-50%);
|
||||||
|
color: #ffd700;
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: bold;
|
||||||
|
text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.loser {
|
||||||
|
animation: explode 1s forwards;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes explode {
|
||||||
|
0% {
|
||||||
|
transform: scale(1);
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
transform: scale(1.5);
|
||||||
|
opacity: 0.5;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: scale(0);
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.commentary {
|
||||||
|
background-color: #2a2a2a;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 10px;
|
||||||
|
margin: 20px 0;
|
||||||
|
min-height: 200px;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.typing {
|
||||||
|
overflow: hidden;
|
||||||
|
border-right: 2px solid #fff;
|
||||||
|
white-space: nowrap;
|
||||||
|
animation: typing 1s steps(40, end),
|
||||||
|
blink-caret 0.75s step-end infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes typing {
|
||||||
|
from { width: 0 }
|
||||||
|
to { width: 100% }
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes blink-caret {
|
||||||
|
from, to { border-color: transparent }
|
||||||
|
50% { border-color: #fff }
|
||||||
|
}
|
||||||
|
|
||||||
|
.upload-form {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-around;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.upload-section {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="file"] {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.upload-btn {
|
||||||
|
background-color: #333;
|
||||||
|
color: white;
|
||||||
|
padding: 10px 20px;
|
||||||
|
border-radius: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background-color 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.upload-btn:hover {
|
||||||
|
background-color: #444;
|
||||||
|
}
|
||||||
|
|
||||||
|
.predict-btn {
|
||||||
|
background-color: #ff0000;
|
||||||
|
color: white;
|
||||||
|
padding: 15px 30px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 18px;
|
||||||
|
margin: 20px auto;
|
||||||
|
display: block;
|
||||||
|
transition: background-color 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.predict-btn:hover {
|
||||||
|
background-color: #cc0000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading {
|
||||||
|
display: none;
|
||||||
|
text-align: center;
|
||||||
|
margin: 20px 0;
|
||||||
|
font-size: 24px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading::before {
|
||||||
|
content: "⚡";
|
||||||
|
display: inline-block;
|
||||||
|
animation: loading 0.5s infinite;
|
||||||
|
margin-right: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading::after {
|
||||||
|
content: "Analyzing Fight...";
|
||||||
|
display: inline-block;
|
||||||
|
animation: pulse 1s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes loading {
|
||||||
|
0% { transform: rotate(0deg); }
|
||||||
|
100% { transform: rotate(360deg); }
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0% { opacity: 0.5; }
|
||||||
|
50% { opacity: 1; }
|
||||||
|
100% { opacity: 0.5; }
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-fight-btn {
|
||||||
|
background-color: #4CAF50;
|
||||||
|
color: white;
|
||||||
|
padding: 15px 30px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 18px;
|
||||||
|
margin: 20px auto;
|
||||||
|
display: none;
|
||||||
|
transition: background-color 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-fight-btn:hover {
|
||||||
|
background-color: #45a049;
|
||||||
|
}
|
||||||
|
|
||||||
|
.buttons-container {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
gap: 20px;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🤼♂️ Fight Predictor 🤼♂️</h1>
|
||||||
|
|
||||||
|
<!-- Hidden iframe for airhorn -->
|
||||||
|
<iframe id="airhorn" width="110" height="200" src="https://www.myinstants.com/instant/dj-airhorn/embed/" frameborder="0" scrolling="no" style="display: none;"></iframe>
|
||||||
|
|
||||||
|
<form id="predictForm" class="upload-form">
|
||||||
|
<div class="upload-section">
|
||||||
|
<label class="upload-btn">
|
||||||
|
Upload Red Corner Fighter
|
||||||
|
<input type="file" name="red_fighter" accept="image/*" required>
|
||||||
|
</label>
|
||||||
|
<div class="fighter red-corner">
|
||||||
|
<img id="redPreview" src="" alt="Red Corner Fighter" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="upload-section">
|
||||||
|
<label class="upload-btn">
|
||||||
|
Upload Blue Corner Fighter
|
||||||
|
<input type="file" name="blue_fighter" accept="image/*" required>
|
||||||
|
</label>
|
||||||
|
<div class="fighter blue-corner">
|
||||||
|
<img id="bluePreview" src="" alt="Blue Corner Fighter" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div class="buttons-container">
|
||||||
|
<button class="predict-btn" onclick="predictFight()">Predict Fight!</button>
|
||||||
|
<button class="new-fight-btn" onclick="resetFight()">New Fight</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="loading" id="loading"></div>
|
||||||
|
<div class="commentary" id="commentary"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
function previewImage(input, previewId) {
|
||||||
|
const preview = document.getElementById(previewId);
|
||||||
|
if (input.files && input.files[0]) {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = function(e) {
|
||||||
|
preview.src = e.target.result;
|
||||||
|
preview.style.display = 'block';
|
||||||
|
}
|
||||||
|
reader.readAsDataURL(input.files[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
document.querySelector('input[name="red_fighter"]').addEventListener('change', function() {
|
||||||
|
previewImage(this, 'redPreview');
|
||||||
|
});
|
||||||
|
|
||||||
|
document.querySelector('input[name="blue_fighter"]').addEventListener('change', function() {
|
||||||
|
previewImage(this, 'bluePreview');
|
||||||
|
});
|
||||||
|
|
||||||
|
function playAirhorn() {
|
||||||
|
const iframe = document.getElementById('airhorn');
|
||||||
|
iframe.contentWindow.postMessage('play', '*');
|
||||||
|
}
|
||||||
|
|
||||||
|
function triggerConfetti() {
|
||||||
|
const duration = 3 * 1000;
|
||||||
|
const animationEnd = Date.now() + duration;
|
||||||
|
const defaults = { startVelocity: 30, spread: 360, ticks: 60, zIndex: 0 };
|
||||||
|
|
||||||
|
function randomInRange(min, max) {
|
||||||
|
return Math.random() * (max - min) + min;
|
||||||
|
}
|
||||||
|
|
||||||
|
const interval = setInterval(function() {
|
||||||
|
const timeLeft = animationEnd - Date.now();
|
||||||
|
|
||||||
|
if (timeLeft <= 0) {
|
||||||
|
return clearInterval(interval);
|
||||||
|
}
|
||||||
|
|
||||||
|
const particleCount = 50 * (timeLeft / duration);
|
||||||
|
|
||||||
|
// since particles fall down, start a bit higher than random
|
||||||
|
confetti({
|
||||||
|
...defaults,
|
||||||
|
particleCount,
|
||||||
|
origin: { x: randomInRange(0.1, 0.3), y: Math.random() - 0.2 }
|
||||||
|
});
|
||||||
|
confetti({
|
||||||
|
...defaults,
|
||||||
|
particleCount,
|
||||||
|
origin: { x: randomInRange(0.7, 0.9), y: Math.random() - 0.2 }
|
||||||
|
});
|
||||||
|
}, 250);
|
||||||
|
}
|
||||||
|
|
||||||
|
function typeText(element, text, speed = 30) {
|
||||||
|
let i = 0;
|
||||||
|
element.innerHTML = '';
|
||||||
|
element.style.display = 'block';
|
||||||
|
|
||||||
|
function type() {
|
||||||
|
if (i < text.length) {
|
||||||
|
element.innerHTML += text.charAt(i);
|
||||||
|
i++;
|
||||||
|
setTimeout(type, speed);
|
||||||
|
} else {
|
||||||
|
// After typing is complete, check for winner
|
||||||
|
const winnerMatch = text.match(/\\boxed{(Red|Blue)}/);
|
||||||
|
if (winnerMatch) {
|
||||||
|
const winner = winnerMatch[1];
|
||||||
|
const redFighter = document.querySelector('.red-corner');
|
||||||
|
const blueFighter = document.querySelector('.blue-corner');
|
||||||
|
|
||||||
|
if (winner === 'Red') {
|
||||||
|
redFighter.classList.add('winner');
|
||||||
|
blueFighter.classList.add('loser');
|
||||||
|
// Trigger effects from red corner
|
||||||
|
triggerConfetti();
|
||||||
|
playAirhorn();
|
||||||
|
} else {
|
||||||
|
blueFighter.classList.add('winner');
|
||||||
|
redFighter.classList.add('loser');
|
||||||
|
// Trigger effects from blue corner
|
||||||
|
triggerConfetti();
|
||||||
|
playAirhorn();
|
||||||
|
}
|
||||||
|
// Show new fight button after prediction is complete
|
||||||
|
document.querySelector('.new-fight-btn').style.display = 'block';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type();
|
||||||
|
}
|
||||||
|
|
||||||
|
function resetFight() {
|
||||||
|
// Reset images
|
||||||
|
document.getElementById('redPreview').style.display = 'none';
|
||||||
|
document.getElementById('redPreview').src = '';
|
||||||
|
document.getElementById('bluePreview').style.display = 'none';
|
||||||
|
document.getElementById('bluePreview').src = '';
|
||||||
|
|
||||||
|
// Reset form
|
||||||
|
document.getElementById('predictForm').reset();
|
||||||
|
|
||||||
|
// Reset classes
|
||||||
|
document.querySelector('.red-corner').classList.remove('winner', 'loser');
|
||||||
|
document.querySelector('.blue-corner').classList.remove('winner', 'loser');
|
||||||
|
|
||||||
|
// Clear and hide commentary
|
||||||
|
const commentary = document.getElementById('commentary');
|
||||||
|
commentary.innerHTML = '';
|
||||||
|
commentary.style.display = 'none';
|
||||||
|
|
||||||
|
// Hide new fight button
|
||||||
|
document.querySelector('.new-fight-btn').style.display = 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
async function predictFight() {
|
||||||
|
const form = document.getElementById('predictForm');
|
||||||
|
const formData = new FormData(form);
|
||||||
|
|
||||||
|
const loading = document.getElementById('loading');
|
||||||
|
const commentary = document.getElementById('commentary');
|
||||||
|
const predictBtn = document.querySelector('.predict-btn');
|
||||||
|
|
||||||
|
loading.style.display = 'block';
|
||||||
|
commentary.style.display = 'none';
|
||||||
|
predictBtn.disabled = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/predict', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
typeText(commentary, data.prediction);
|
||||||
|
} else {
|
||||||
|
commentary.style.display = 'block';
|
||||||
|
commentary.innerHTML = `Error: ${data.error}`;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
commentary.style.display = 'block';
|
||||||
|
commentary.innerHTML = `Error: ${error.message}`;
|
||||||
|
} finally {
|
||||||
|
loading.style.display = 'none';
|
||||||
|
predictBtn.disabled = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
326
environments/hack0/ufc_env/ufc_image_env.py
Normal file
|
|
@ -0,0 +1,326 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import csv
|
||||||
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
|
import base64
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||||
|
from atroposlib.type_definitions import GameHistory, Item
|
||||||
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
|
|
||||||
|
|
||||||
|
class UFCImageEnvConfig(BaseEnvConfig):
|
||||||
|
"""Configuration for the UFC Image Environment"""
|
||||||
|
fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
|
||||||
|
fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
|
||||||
|
image_folder: str = Field(os.path.join(os.path.dirname(__file__), "fighter_images"), description="Path to fighter images folder")
|
||||||
|
max_steps: int = Field(1, description="Only one step per fight prediction")
|
||||||
|
temperature: float = Field(0.7, description="Temperature for generation diversity")
|
||||||
|
top_p: float = Field(0.95, description="Top p for nucleus sampling")
|
||||||
|
|
||||||
|
|
||||||
|
class UFCImageEnv(BaseEnv):
|
||||||
|
"""UFC Fight Prediction Environment using only fighter images"""
|
||||||
|
name = "ufc_image_predictor"
|
||||||
|
env_config_cls = UFCImageEnvConfig
|
||||||
|
|
||||||
|
def __init__(self, config: UFCImageEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
|
||||||
|
super().__init__(config, server_configs, slurm, testing)
|
||||||
|
self.fighter_stats = {}
|
||||||
|
self.fight_data = []
|
||||||
|
self.current_index = 0
|
||||||
|
self.inference_server = self.server.servers[0] # Get first server as inference server
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""Load the fighter stats and fight data"""
|
||||||
|
try:
|
||||||
|
print("Loading fighter stats from:", self.config.fighter_stats_path)
|
||||||
|
with open(self.config.fighter_stats_path, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
self.fighter_stats = {row["name"]: row for row in reader}
|
||||||
|
print(f"Loaded stats for {len(self.fighter_stats)} fighters")
|
||||||
|
|
||||||
|
print("Loading fight data from:", self.config.fight_data_path)
|
||||||
|
with open(self.config.fight_data_path, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
self.fight_data = list(reader)
|
||||||
|
print(f"Loaded {len(self.fight_data)} fights")
|
||||||
|
|
||||||
|
# Filter out fights where either fighter's image is missing
|
||||||
|
filtered_fights = []
|
||||||
|
missing_images = set() # Track unique missing images
|
||||||
|
for fight in self.fight_data:
|
||||||
|
r_fighter = fight["r_fighter"]
|
||||||
|
b_fighter = fight["b_fighter"]
|
||||||
|
|
||||||
|
# Convert names to image filename format
|
||||||
|
r_slug = r_fighter.lower().replace(" ", "-")
|
||||||
|
b_slug = b_fighter.lower().replace(" ", "-")
|
||||||
|
|
||||||
|
r_image_path = os.path.join(self.config.image_folder, f"{r_slug}.jpg")
|
||||||
|
b_image_path = os.path.join(self.config.image_folder, f"{b_slug}.jpg")
|
||||||
|
|
||||||
|
if os.path.exists(r_image_path) and os.path.exists(b_image_path):
|
||||||
|
filtered_fights.append(fight)
|
||||||
|
else:
|
||||||
|
if not os.path.exists(r_image_path):
|
||||||
|
missing_images.add(r_fighter)
|
||||||
|
if not os.path.exists(b_image_path):
|
||||||
|
missing_images.add(b_fighter)
|
||||||
|
|
||||||
|
if missing_images:
|
||||||
|
print(f"\nMissing images for {len(missing_images)} fighters. These fights will be skipped.")
|
||||||
|
|
||||||
|
self.fight_data = filtered_fights
|
||||||
|
print(f"Filtered to {len(self.fight_data)} fights with complete image sets")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading data: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def get_fighter_image(self, fighter_name):
|
||||||
|
"""Convert fighter name to image path and return base64 encoded image"""
|
||||||
|
try:
|
||||||
|
# Convert name to slug format
|
||||||
|
slug = fighter_name.lower().replace(" ", "-")
|
||||||
|
|
||||||
|
image_path = os.path.join(self.config.image_folder, f"{slug}.jpg")
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert image to base64
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
# Convert RGBA to RGB if necessary
|
||||||
|
if img.mode == 'RGBA':
|
||||||
|
img = img.convert('RGB')
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="JPEG")
|
||||||
|
image_bytes = buf.getvalue()
|
||||||
|
return base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting image for {fighter_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_next_item(self) -> Optional[Item]:
|
||||||
|
"""Get the next fight from the dataset"""
|
||||||
|
try:
|
||||||
|
if self.current_index >= len(self.fight_data):
|
||||||
|
return None
|
||||||
|
fight = self.fight_data[self.current_index]
|
||||||
|
self.current_index += 1
|
||||||
|
|
||||||
|
r_fighter = fight["r_fighter"]
|
||||||
|
b_fighter = fight["b_fighter"]
|
||||||
|
|
||||||
|
# Get base64 encoded images
|
||||||
|
r_image = self.get_fighter_image(r_fighter)
|
||||||
|
b_image = self.get_fighter_image(b_fighter)
|
||||||
|
|
||||||
|
if not r_image or not b_image:
|
||||||
|
print(f"Skipping fight {self.current_index} due to missing images")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Format the prompt with images
|
||||||
|
prompt_text = (
|
||||||
|
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment "
|
||||||
|
"Let's break down this matchup that's got everyone talking!\n\n"
|
||||||
|
"In the red corner, we have:(YOUR FIRST IMAGE):\n"
|
||||||
|
"And in the blue corner: (YOUR SECOND IMAGE):\n\n"
|
||||||
|
"Now, act as your favorite fight comentator, I want you to:\n"
|
||||||
|
"create a fight commentary of whats happening in the fight live\n"
|
||||||
|
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
|
||||||
|
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
|
||||||
|
"End your masterpiece with your prediction in this exact format:\n"
|
||||||
|
"[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs)\n\n"
|
||||||
|
"The winner should always be annouced with"
|
||||||
|
"\\boxed{Red} or \\boxed{Blue}"
|
||||||
|
"Or you will receive a score of -1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create multimodal prompt with images
|
||||||
|
prompt = tuple([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt_text},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{r_image}"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{b_image}"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
winner = fight.get("winner", "") # Red or Blue
|
||||||
|
ground_truth = f"Answer: {winner}" if winner else ""
|
||||||
|
|
||||||
|
return (prompt, ground_truth, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in get_next_item: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
|
||||||
|
to_score = []
|
||||||
|
to_backlog = []
|
||||||
|
|
||||||
|
system_msg = {
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are an expert MMA analyst. You will be given two UFC fighters' images. "
|
||||||
|
"Your task is to predict the winner of the fight based on their appearance and physique.\n\n"
|
||||||
|
"IMPORTANT: You MUST format your response in exactly two parts:\n"
|
||||||
|
"1. First, analyze the fighters' appearances and create a fight commentary\n"
|
||||||
|
"2. Then on a new line, give ONLY your final prediction in this exact format:\n"
|
||||||
|
"\\boxed{Red} or \\boxed{Blue}\n\n"
|
||||||
|
"For example:\n"
|
||||||
|
"After analyzing the fighters' appearances... [your analysis here]\n"
|
||||||
|
"\\boxed{Red}\n\n"
|
||||||
|
"If you do not end your response with the \\boxed{} format containing either 'Red' or 'Blue', you will receive a score of -1.0."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
user_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": dict(item[0][0])["content"]
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = [system_msg, user_msg]
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_completions = await self.inference_server.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
n=self.config.group_size,
|
||||||
|
max_tokens=2048,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
top_p=self.config.top_p,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
for choice in chat_completions.choices:
|
||||||
|
assistant_msg = {"role": "assistant", "content": choice.message.content}
|
||||||
|
history = [
|
||||||
|
{"role": "system", "content": system_msg["content"]},
|
||||||
|
{"role": "user", "content": user_msg["content"]},
|
||||||
|
{"role": "assistant", "content": choice.message.content}
|
||||||
|
]
|
||||||
|
to_score.append((history, item[1], None))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in collect_trajectories: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
to_backlog.append(item)
|
||||||
|
|
||||||
|
if not to_score:
|
||||||
|
return None, to_backlog
|
||||||
|
|
||||||
|
scored_data = await self.score(to_score)
|
||||||
|
return scored_data, to_backlog
|
||||||
|
|
||||||
|
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||||
|
if not rollout_group_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scores = ScoredDataGroup()
|
||||||
|
scores["tokens"] = []
|
||||||
|
scores["masks"] = []
|
||||||
|
scores["scores"] = []
|
||||||
|
scores["advantages"] = None
|
||||||
|
scores["ref_logprobs"] = None
|
||||||
|
scores["messages"] = None
|
||||||
|
scores["group_overrides"] = {"group_size": self.config.group_size}
|
||||||
|
scores["overrides"] = None
|
||||||
|
scores["ground_truths"] = []
|
||||||
|
|
||||||
|
random.shuffle(rollout_group_data)
|
||||||
|
for item in rollout_group_data:
|
||||||
|
out = tokenize_for_trainer(self.tokenizer, item[0])
|
||||||
|
tokens = out["tokens"]
|
||||||
|
masks = out["masks"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract prediction and ground truth
|
||||||
|
reply = item[0][-1]["content"]
|
||||||
|
ground_truth = item[1].strip().lower()
|
||||||
|
|
||||||
|
# Extract color from ground truth (format: "answer: color")
|
||||||
|
ground_truth_color = ground_truth.replace("answer:", "").strip()
|
||||||
|
|
||||||
|
# Extract color from \boxed{color} format
|
||||||
|
import re
|
||||||
|
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
|
||||||
|
if boxed_match:
|
||||||
|
prediction = boxed_match.group(1).strip().lower()
|
||||||
|
# Compare just the colors
|
||||||
|
reward = 1.0 if prediction == ground_truth_color else -1.0
|
||||||
|
else:
|
||||||
|
# No boxed answer found
|
||||||
|
reward = -1.0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error scoring response: {e}")
|
||||||
|
reward = -1.0
|
||||||
|
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||||
|
|
||||||
|
if len([i for i in masks if i != -100]) < 10:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scores["tokens"].append(tokens)
|
||||||
|
scores["masks"].append(masks)
|
||||||
|
scores["scores"].append(reward)
|
||||||
|
scores["ground_truths"].append(ground_truth)
|
||||||
|
|
||||||
|
if len(scores["tokens"]) >= self.config.group_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not scores["tokens"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
async def evaluate(self, *args, **kwargs):
|
||||||
|
"""No-op evaluation"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||||
|
"""Initialize configuration for the environment"""
|
||||||
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
|
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
config = UFCImageEnvConfig(
|
||||||
|
wandb_name="ufc_image",
|
||||||
|
tokenizer_name="gpt2",
|
||||||
|
group_size=2,
|
||||||
|
use_wandb=False,
|
||||||
|
max_num_workers=2,
|
||||||
|
rollout_server_url="http://localhost:8000",
|
||||||
|
total_steps=1000,
|
||||||
|
batch_size=1,
|
||||||
|
steps_per_eval=10,
|
||||||
|
ensure_scores_are_not_same=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_configs = [
|
||||||
|
OpenaiConfig(
|
||||||
|
model_name="gpt-4o",
|
||||||
|
base_url=None,
|
||||||
|
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||||
|
num_requests_for_eval=1,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return config, server_configs
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
UFCImageEnv.cli()
|
||||||
104
environments/hack0/ufc_env/ufc_predictor_ui.py
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from flask import Flask, render_template, request, jsonify
|
||||||
|
from PIL import Image
|
||||||
|
import openai
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
||||||
|
|
||||||
|
# Initialize OpenAI client
|
||||||
|
client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
|
||||||
|
|
||||||
|
def process_image(image_file):
|
||||||
|
"""Convert uploaded image to base64"""
|
||||||
|
img = Image.open(image_file)
|
||||||
|
# Convert RGBA to RGB if necessary
|
||||||
|
if img.mode == 'RGBA':
|
||||||
|
img = img.convert('RGB')
|
||||||
|
buffered = BytesIO()
|
||||||
|
img.save(buffered, format="JPEG")
|
||||||
|
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def home():
|
||||||
|
return render_template('predictor.html')
|
||||||
|
|
||||||
|
@app.route('/predict', methods=['POST'])
|
||||||
|
def predict():
|
||||||
|
try:
|
||||||
|
# Get uploaded images
|
||||||
|
red_fighter = request.files['red_fighter']
|
||||||
|
blue_fighter = request.files['blue_fighter']
|
||||||
|
|
||||||
|
if not red_fighter or not blue_fighter:
|
||||||
|
return jsonify({'error': 'Please upload both fighter images'}), 400
|
||||||
|
|
||||||
|
# Process images to base64
|
||||||
|
red_image = process_image(red_fighter)
|
||||||
|
blue_image = process_image(blue_fighter)
|
||||||
|
|
||||||
|
# Create the prompt
|
||||||
|
prompt_text = (
|
||||||
|
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment "
|
||||||
|
"Let's break down this matchup that's got everyone talking!\n\n"
|
||||||
|
"In the red corner, we have:(YOUR FIRST IMAGE):\n"
|
||||||
|
"And in the blue corner: (YOUR SECOND IMAGE):\n\n"
|
||||||
|
"Now, as your favorite fight comentator, I want you to:\n"
|
||||||
|
"create a fight commentary of whats happening in the fight live\n"
|
||||||
|
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
|
||||||
|
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
|
||||||
|
"End your masterpiece with your prediction in this exact format:\n"
|
||||||
|
"\\boxed{Red} or \\boxed{Blue}"
|
||||||
|
"PLEASE FORMAT THE COMMENTARY IN THE EXACT FORMAT AS THE EXAMPLE BELOW:\n"
|
||||||
|
"[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs) ( add lots of coughs and laughs)\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the messages for the API call
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt_text},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{red_image}"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{blue_image}"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the API call
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=2048,
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.95
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the prediction
|
||||||
|
prediction = response.choices[0].message.content
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'prediction': prediction,
|
||||||
|
'success': True
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'error': str(e),
|
||||||
|
'success': False
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(debug=True)
|
||||||
233
environments/hack0/ufc_env/ufc_server.py
Normal file
|
|
@ -0,0 +1,233 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import csv
|
||||||
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||||
|
from atroposlib.type_definitions import GameHistory, Item
|
||||||
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
|
|
||||||
|
|
||||||
|
class UFCEnvConfig(BaseEnvConfig):
|
||||||
|
"""Configuration for the UFC Environment"""
|
||||||
|
fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
|
||||||
|
fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
|
||||||
|
max_steps: int = Field(1, description="Only one step per fight prediction")
|
||||||
|
temperature: float = Field(0.7, description="Temperature for generation diversity")
|
||||||
|
top_p: float = Field(0.95, description="Top p for nucleus sampling")
|
||||||
|
|
||||||
|
|
||||||
|
class UFCEnv(BaseEnv):
|
||||||
|
"""UFC Fight Prediction Environment"""
|
||||||
|
name = "ufc_predictor"
|
||||||
|
env_config_cls = UFCEnvConfig
|
||||||
|
|
||||||
|
def __init__(self, config: UFCEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
|
||||||
|
super().__init__(config, server_configs, slurm, testing)
|
||||||
|
self.fighter_stats = {}
|
||||||
|
self.fight_data = []
|
||||||
|
self.current_index = 0
|
||||||
|
self.inference_server = self.server.servers[0] # Get first server as inference server
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""Load the fighter stats and fight data"""
|
||||||
|
try:
|
||||||
|
print("Loading fighter stats from:", self.config.fighter_stats_path)
|
||||||
|
with open(self.config.fighter_stats_path, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
self.fighter_stats = {row["name"]: row for row in reader}
|
||||||
|
print(f"Loaded stats for {len(self.fighter_stats)} fighters")
|
||||||
|
|
||||||
|
print("Loading fight data from:", self.config.fight_data_path)
|
||||||
|
with open(self.config.fight_data_path, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
self.fight_data = list(reader)
|
||||||
|
print(f"Loaded {len(self.fight_data)} fights")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading data: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
async def get_next_item(self) -> Optional[Item]:
|
||||||
|
"""Get the next fight from the dataset"""
|
||||||
|
try:
|
||||||
|
if self.current_index >= len(self.fight_data):
|
||||||
|
return None
|
||||||
|
fight = self.fight_data[self.current_index]
|
||||||
|
self.current_index += 1
|
||||||
|
|
||||||
|
r_fighter = fight["r_fighter"]
|
||||||
|
b_fighter = fight["b_fighter"]
|
||||||
|
r_stats = self.fighter_stats.get(r_fighter, {})
|
||||||
|
b_stats = self.fighter_stats.get(b_fighter, {})
|
||||||
|
|
||||||
|
# Format the prompt
|
||||||
|
def stats_str(name, stats):
|
||||||
|
if not stats:
|
||||||
|
return f"{name}: (No stats available)"
|
||||||
|
return (
|
||||||
|
f"Name: {name}\n"
|
||||||
|
f"Wins: {stats.get('wins','?')} Losses: {stats.get('losses','?')} Age: {stats.get('age','?')}\n"
|
||||||
|
f"Height: {stats.get('height','?')} cm Weight: {stats.get('weight','?')} kg Reach: {stats.get('reach','?')} cm Stance: {stats.get('stance','?')}\n"
|
||||||
|
f"SLpM: {stats.get('SLpM','?')} Sig Str Acc: {stats.get('sig_str_acc','?')} SApM: {stats.get('SApM','?')} Str Def: {stats.get('str_def','?')}\n"
|
||||||
|
f"TD Avg: {stats.get('td_avg','?')} TD Acc: {stats.get('td_acc','?')} TD Def: {stats.get('td_def','?')} Sub Avg: {stats.get('sub_avg','?')}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_text = (
|
||||||
|
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment - the UFC Fight Prediction Show! "
|
||||||
|
"Let's break down this matchup that's got everyone talking!\n\n"
|
||||||
|
f"*Drumroll please* In the red corner, we have :\n{stats_str(r_fighter, r_stats)}\n\n"
|
||||||
|
f"And in the blue corner:\n{stats_str(b_fighter, b_stats)}\n\n"
|
||||||
|
"Now, as your favorite fight analyst who's definitely not just making this up as I go along, I want you to:\n"
|
||||||
|
"1. Break down these fighters like you're explaining why your favorite TV show character would win in a fight\n"
|
||||||
|
"2. Compare their styles\n"
|
||||||
|
"3. Point out their advantages\n"
|
||||||
|
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
|
||||||
|
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
|
||||||
|
"End your masterpiece with the winner's name in this exact format:\n"
|
||||||
|
"\\boxed{fighter name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = tuple([
|
||||||
|
frozenset({"role": "user", "content": prompt_text}.items())
|
||||||
|
])
|
||||||
|
|
||||||
|
winner = fight.get("winner", "") # Red or Blue
|
||||||
|
winner_name = r_fighter if winner == "Red" else b_fighter if winner == "Blue" else ""
|
||||||
|
ground_truth = f"Answer: {winner_name}" if winner_name else ""
|
||||||
|
|
||||||
|
return (prompt, ground_truth, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in get_next_item: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
|
||||||
|
to_score = []
|
||||||
|
to_backlog = []
|
||||||
|
|
||||||
|
system_msg = {
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are an expert MMA analyst. You will be given two UFC fighters and their stats. "
|
||||||
|
"Your task is to predict the winner of the fight based on their statistics.\n\n"
|
||||||
|
"IMPORTANT: You MUST format your response in exactly two parts:\n"
|
||||||
|
"1. First, analyze the fighters' stats and explain create a fight commentary\n"
|
||||||
|
"2. Then on a new line, give ONLY your final prediction in this exact format:\n"
|
||||||
|
"\\boxed{fighter name}\n\n"
|
||||||
|
"For example:\n"
|
||||||
|
"After analyzing stats... [your analysis here]\n"
|
||||||
|
"\\boxed{John Smith}\n\n"
|
||||||
|
"If you do not end your response with the \\boxed{} format, you will receive a score of -1.0."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
user_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": dict(item[0][0])["content"]
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = [system_msg, user_msg]
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_completions = await self.inference_server.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
n=self.config.group_size,
|
||||||
|
max_tokens=2048, # Increased from 512 to allow for longer, more detailed fight commentaries
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
top_p=self.config.top_p,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
for choice in chat_completions.choices:
|
||||||
|
assistant_msg = {"role": "assistant", "content": choice.message.content}
|
||||||
|
history = [
|
||||||
|
{"role": "system", "content": system_msg["content"]},
|
||||||
|
{"role": "user", "content": user_msg["content"]},
|
||||||
|
{"role": "assistant", "content": choice.message.content}
|
||||||
|
]
|
||||||
|
to_score.append((history, item[1], None))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in collect_trajectories: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
to_backlog.append(item)
|
||||||
|
|
||||||
|
if not to_score:
|
||||||
|
return None, to_backlog
|
||||||
|
|
||||||
|
scored_data = await self.score(to_score)
|
||||||
|
return scored_data, to_backlog
|
||||||
|
|
||||||
|
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||||
|
if not rollout_group_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scores = ScoredDataGroup()
|
||||||
|
scores["tokens"] = []
|
||||||
|
scores["masks"] = []
|
||||||
|
scores["scores"] = []
|
||||||
|
scores["advantages"] = None
|
||||||
|
scores["ref_logprobs"] = None
|
||||||
|
scores["messages"] = None
|
||||||
|
scores["group_overrides"] = {"group_size": self.config.group_size}
|
||||||
|
scores["overrides"] = None
|
||||||
|
scores["ground_truths"] = []
|
||||||
|
|
||||||
|
random.shuffle(rollout_group_data)
|
||||||
|
for item in rollout_group_data:
|
||||||
|
out = tokenize_for_trainer(self.tokenizer, item[0])
|
||||||
|
tokens = out["tokens"]
|
||||||
|
masks = out["masks"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract prediction and ground truth
|
||||||
|
reply = item[0][-1]["content"]
|
||||||
|
ground_truth = item[1].strip().lower()
|
||||||
|
|
||||||
|
# Extract name from ground truth (format: "answer: name")
|
||||||
|
ground_truth_name = ground_truth.replace("answer:", "").strip()
|
||||||
|
|
||||||
|
# Extract name from \boxed{name} format
|
||||||
|
import re
|
||||||
|
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
|
||||||
|
if boxed_match:
|
||||||
|
prediction = boxed_match.group(1).strip().lower()
|
||||||
|
# Compare just the names
|
||||||
|
reward = 1.0 if prediction == ground_truth_name else -1.0
|
||||||
|
else:
|
||||||
|
# No boxed answer found
|
||||||
|
reward = -1.0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error scoring response: {e}")
|
||||||
|
reward = -1.0
|
||||||
|
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||||
|
|
||||||
|
if len([i for i in masks if i != -100]) < 10:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scores["tokens"].append(tokens)
|
||||||
|
scores["masks"].append(masks)
|
||||||
|
scores["scores"].append(reward)
|
||||||
|
scores["ground_truths"].append(ground_truth)
|
||||||
|
|
||||||
|
if len(scores["tokens"]) >= self.config.group_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not scores["tokens"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
async def evaluate(self, *args, **kwargs):
|
||||||
|
"""No-op evaluation"""
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
UFCEnv.cli()
|
||||||