init commit
127
environments/hack0/ufc_env/README.md
Normal file
@ -0,0 +1,127 @@
# UFC Fight Prediction Environment

This environment provides a framework for training and evaluating AI models on UFC fight prediction tasks, with a twist: instead of a bare analytical prediction, the model generates entertaining fight commentary that can be used directly with Text-to-Speech (TTS) models such as DIA. The environment has two main components, a text-based predictor and an image-based predictor, both designed to produce engaging, broadcast-style fight commentary.

## Environment Design

### Core Components

1. **UFC Server (ufc_server.py)**
   - Text-based fight prediction environment
   - Generates dynamic, entertaining fight commentary
   - Uses fighter statistics and historical data
   - Outputs TTS-ready commentary with dramatic flair
   - Implements a scoring system for model evaluation

2. **UFC Image Environment (ufc_image_env.py)**
   - Visual-based fight prediction environment
   - Creates commentary based on fighter appearances
   - Implements multimodal prediction capabilities
   - Generates broadcast-style commentary from visual analysis
   - Includes image processing and base64 encoding utilities

### Data Structure

- **fighter_stats.csv**: Contains detailed statistics for each fighter including:
  - Win/Loss records
  - Physical attributes (height, weight, reach)
  - Performance metrics (strikes per minute, takedown accuracy, etc.)

- **large_dataset.csv**: Historical fight data including:
  - Fighter matchups
  - Fight outcomes
  - Event information

- **fighter_images/**: Directory containing fighter profile images
  - Images are stored in JPG format
  - Filenames follow slug format (e.g., "john-smith.jpg")
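
The slug convention can be sketched as follows. This is a minimal sketch that mirrors the `make_url` logic in `get_images.py`; the helper names `fighter_slug` and `image_filename` are hypothetical and not part of the environment. Note that `ufc_image_env.py` in this commit uses the simpler `name.lower().replace(" ", "-")`, which agrees with this sketch only for names made of plain letters and single spaces.

```python
import re


def fighter_slug(name: str) -> str:
    """Hypothetical helper: turn a fighter name into an image filename stem."""
    slug = name.lower().strip()
    slug = re.sub(r"[^a-z\s-]", "", slug)  # keep letters, whitespace, hyphens
    slug = re.sub(r"\s+", "-", slug)       # whitespace runs become single hyphens
    return slug


def image_filename(name: str) -> str:
    """Hypothetical helper: filename expected inside fighter_images/."""
    return f"{fighter_slug(name)}.jpg"


print(image_filename("John Smith"))
print(image_filename("B.J. Penn"))
```

Names with punctuation lose it entirely (periods and apostrophes are dropped), so a lookup that only replaces spaces would miss those files.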

## Motivation

This environment was designed to transform traditional fight prediction into an engaging entertainment experience:

1. **Entertainment-First Approach**
   - Generates dynamic, broadcast-style fight commentary
   - Creates TTS-ready output for voice synthesis
   - Incorporates dramatic elements and commentator personalities
   - Makes fight prediction more engaging and accessible

2. **Statistical Analysis with Style**
   - Wraps technical analysis in entertaining commentary
   - Uses fight statistics to inform dramatic storytelling
   - Aims to keep predictions accurate while being entertaining
   - Creates a more engaging way to present fight analysis

3. **Visual Storytelling**
   - Transforms visual analysis into engaging commentary
   - Creates dramatic narratives from fighter appearances
   - Makes technical observations more accessible
   - Generates TTS-compatible descriptions of visual elements

4. **Multimodal Entertainment**
   - Combines statistical and visual data for rich commentary
   - Creates cohesive narratives from multiple data sources
   - Generates engaging stories that work well with TTS
   - Makes technical analysis more accessible and fun

## Usage

1. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```

2. Prepare data:
   - Ensure `fighter_stats.csv` and `large_dataset.csv` are in the environment directory
   - Place fighter images in the `fighter_images/` directory

3. Run the environment:
   - For text-based commentary: Use `UFCEnv` (from `ufc_server.py`)
   - For image-based commentary: Use `UFCImageEnv` (from `ufc_image_env.py`)

4. TTS Integration:
   - The generated commentary is formatted for direct use with TTS models
   - Includes dramatic pauses and emphasis markers
   - Contains natural speech patterns and commentator personalities
   - Ready for voice synthesis with models like DIA
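
Before commentary is handed to a TTS model, the machine-readable verdict usually needs to be separated from the spoken text. The sketch below assumes the reply carries the `\boxed{Red}` / `\boxed{Blue}` marker that the prompts in `ufc_image_env.py` ask for; the function name `split_commentary` is hypothetical and the environment may handle this differently.

```python
import re

# Matches the literal marker \boxed{Red} or \boxed{Blue}.
BOXED = re.compile(r"\\boxed\{(Red|Blue)\}")


def split_commentary(reply: str):
    """Return (spoken_text, winner); winner is None when no marker is found."""
    matches = BOXED.findall(reply)
    winner = matches[-1] if matches else None   # the last marker is the final call
    spoken_text = BOXED.sub("", reply).strip()  # remove markers from the TTS input
    return spoken_text, winner


reply = "[S1] What a fight! [S2] Unreal.\n\\boxed{Red}"
spoken_text, winner = split_commentary(reply)
print(winner)
print(spoken_text)
```

Taking the last match is a design choice for replies that mention both corners before settling on a final call.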

## Example Runs

Here are some example runs demonstrating the environment in action:

- [Video Demo](https://youtu.be/C_hFe6TfQvU) - Watch the environment in action with real-time commentary generation
- [Text-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/rq5wfxgh?nw=nwuseredtheman) - Shows the environment generating commentary based on fighter statistics and historical data
- [Image-based Prediction Run](https://wandb.ai/edtheman/Atropos-environments_ufc_env/runs/klw4m5of?nw=nwuseredtheman) - Demonstrates the environment creating commentary from visual analysis of fighter appearances

The key difference between these runs is their input modality:
- The text-based run focuses on statistical analysis and historical data to generate commentary
- The image-based run analyzes fighter appearances and visual characteristics to create engaging narratives

## Configuration

The environment can be configured through the following parameters (defaults shown are those of `UFCImageEnvConfig`):

- `fighter_stats_path`: Path to fighter statistics CSV (default: `fighter_stats.csv` next to the module)
- `fight_data_path`: Path to fight dataset CSV (default: `large_dataset.csv` next to the module)
- `image_folder`: Path to fighter images directory (default: `fighter_images/` next to the module)
- `max_steps`: Number of steps per prediction (default: 1)
- `temperature`: Generation diversity parameter, affects commentary style (default: 0.7)
- `top_p`: Nucleus sampling parameter, affects commentary creativity (default: 0.95)
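
To illustrate how these parameters fit together, here is a stand-in sketch using a plain dataclass. It is not the real config class: `UFCImageEnvConfig` in `ufc_image_env.py` is a pydantic model derived from `BaseEnvConfig`, and `DemoUFCConfig` and `ENV_DIR` below are hypothetical names that only mirror its field names and default values.

```python
import os
from dataclasses import dataclass, replace

# Directory of this environment as it appears in the commit; adjust to your checkout.
ENV_DIR = os.path.join("environments", "hack0", "ufc_env")


@dataclass(frozen=True)
class DemoUFCConfig:
    """Stand-in mirroring the fields and defaults of UFCImageEnvConfig."""
    fighter_stats_path: str = os.path.join(ENV_DIR, "fighter_stats.csv")
    fight_data_path: str = os.path.join(ENV_DIR, "large_dataset.csv")
    image_folder: str = os.path.join(ENV_DIR, "fighter_images")
    max_steps: int = 1         # one step per fight prediction
    temperature: float = 0.7   # generation diversity
    top_p: float = 0.95        # nucleus sampling threshold


default = DemoUFCConfig()
livelier = replace(default, temperature=1.0)  # copy with a single override
print(default.temperature, livelier.temperature)
```

Keeping the config immutable and deriving variants with `replace` makes it easy to compare runs that differ in a single sampling parameter.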

## Scoring System

The environment implements a scoring system that evaluates predictions based on:
- Accuracy of winner prediction
- Entertainment value of the commentary
- TTS compatibility and natural flow
- Integration of statistical/visual data in an engaging way
- Proper formatting for voice synthesis

Both the user prompt and the system prompt in `ufc_image_env.py` state that a response lacking a final `\boxed{Red}` or `\boxed{Blue}` receives a score of -1.0.
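
The winner-accuracy component can be sketched as below. Only two things come from the source: the -1.0 penalty for a missing `\boxed{...}` marker (stated in the prompts) and the `"Answer: <winner>"` ground-truth format built in `get_next_item`. The 1.0 and 0.0 values for correct and incorrect picks are assumptions made for illustration, `winner_score` is a hypothetical name, and the other criteria listed above are not modeled.

```python
import re

BOXED = re.compile(r"\\boxed\{(Red|Blue)\}")


def winner_score(reply: str, ground_truth: str) -> float:
    """Score only the winner prediction of a single reply."""
    matches = BOXED.findall(reply)
    if not matches:
        return -1.0  # penalty stated in the prompts for a missing marker
    predicted = matches[-1].lower()
    # Ground truth looks like "Answer: Red"; keep the part after the colon.
    actual = ground_truth.strip().lower().split(":")[-1].strip()
    return 1.0 if predicted == actual else 0.0  # assumed reward values


print(winner_score("What a finish!\n\\boxed{Red}", "Answer: Red"))
```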

## Contributing

Contributions are welcome! Please feel free to submit pull requests for:
- New commentary styles and personalities
- Enhanced TTS compatibility features
- Additional dramatic elements
- Improved entertainment value
- Better integration with voice synthesis models
3
environments/hack0/ufc_env/__init__.py
Normal file
@ -0,0 +1,3 @@
from .ufc_server import UFCEnv, UFCEnvConfig

__all__ = ["UFCEnv", "UFCEnvConfig"]
BIN
environments/hack0/ufc_env/fighter_images/matt-frevola.jpg
Normal file
Size: 344 KiB
BIN
environments/hack0/ufc_env/fighter_images/matt-hamill.jpg
Normal file
Size: 167 KiB
BIN
environments/hack0/ufc_env/fighter_images/randy-brown.jpg
Normal file
Size: 414 KiB
BIN
environments/hack0/ufc_env/fighter_images/tom-nolan.jpg
Normal file
Size: 353 KiB
BIN
environments/hack0/ufc_env/fighter_images/vagner-rocha.jpg
Normal file
Size: 117 KiB
BIN
environments/hack0/ufc_env/fighter_images/yanis-ghemmouri.jpg
Normal file
Size: 360 KiB
BIN
environments/hack0/ufc_env/fighter_images/zhalgas-zhumagulov.jpg
Normal file
Size: 330 KiB
BIN
environments/hack0/ufc_env/fighter_images/zviad-lazishvili.jpg
Normal file
Size: 411 KiB
2480
environments/hack0/ufc_env/fighter_stats.csv
Normal file
113
environments/hack0/ufc_env/get_images.py
Normal file
@ -0,0 +1,113 @@
import os
import random
import re
import time

import pandas as pd
import requests
from bs4 import BeautifulSoup

# Seconds to wait for a response before giving up on a request
REQUEST_TIMEOUT = 30

# 1. Load your CSV
df = pd.read_csv("fighter_stats.csv")

# List of different user agents to rotate between
USER_AGENTS = [
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Edg/91.0.864.59',
    'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
]

# List of different accept languages
ACCEPT_LANGUAGES = [
    'en-US,en;q=0.9',
    'en-GB,en;q=0.9',
    'en-CA,en;q=0.9',
    'en-AU,en;q=0.9',
    'en-NZ,en;q=0.9',
]

def get_random_headers():
    """Generate a random set of headers for each request."""
    return {
        'User-Agent': random.choice(USER_AGENTS),
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
        'Accept-Language': random.choice(ACCEPT_LANGUAGES),
        'Connection': 'keep-alive',
        'Upgrade-Insecure-Requests': '1',
        'Cache-Control': 'max-age=0',
        'Sec-Fetch-Dest': 'document',
        'Sec-Fetch-Mode': 'navigate',
        'Sec-Fetch-Site': 'none',
        'Sec-Fetch-User': '?1',
        'DNT': '1',
    }

# 2. Generate UFC profile URLs
def make_url(name):
    """Turn a fighter name into a UFC athlete URL slug."""
    if not isinstance(name, str) or not name.strip():
        return None
    slug = name.lower().strip()
    slug = re.sub(r"['\u2018\u2019]", "", slug)  # Remove straight and curly apostrophes
    slug = re.sub(r"[^a-z\s-]", "", slug)  # Keep letters, spaces, hyphens
    slug = re.sub(r"\s+", "-", slug)  # Spaces → hyphens
    return f"https://www.ufc.com/athlete/{slug}"

df["UFC_Profile_URL"] = df["name"].apply(make_url)

# 3. Prepare output folder
output_folder = "fighter_images"
os.makedirs(output_folder, exist_ok=True)

# 4. Scrape & download each image
def download_ufc_image(url, counter):
    """Fetch the UFC athlete page, parse out the main profile image, and save it."""
    try:
        # Derive filename from URL slug
        slug = url.rstrip("/").split("/")[-1]
        filename = f"{slug}.jpg"
        path = os.path.join(output_folder, filename)

        # Check if image already exists
        if os.path.exists(path):
            print(f"[i] Image already exists for {filename}")
            return

        # Get fresh random headers for each request
        headers = get_random_headers()

        resp = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        img_tag = soup.select_one("div.hero-profile__image-wrap img.hero-profile__image")
        if not img_tag or not img_tag.get("src"):
            print(f"[!] No image found at {url}")
            return
        img_url = img_tag["src"]

        # Get fresh random headers for image download
        headers = get_random_headers()
        img_resp = requests.get(img_url, headers=headers, timeout=REQUEST_TIMEOUT)
        img_resp.raise_for_status()  # Avoid saving an error page as a .jpg
        with open(path, "wb") as f:
            f.write(img_resp.content)
        print(f"[✓] Saved {filename}")

        # Add random delay between 1-3 seconds
        time.sleep(random.uniform(1, 3))

        # After every 100 downloads, take a longer break
        if counter % 100 == 0:
            print(f"[i] Taking a 30-second break after {counter} downloads...")
            time.sleep(30)

    except Exception as e:
        print(f"[✗] Error at {url}: {e}")

# 5. Iterate and download
counter = 0
for url in df["UFC_Profile_URL"].dropna().unique():
    counter += 1
    download_ufc_image(url, counter)
7440
environments/hack0/ufc_env/large_dataset.csv
Normal file
8
environments/hack0/ufc_env/requirements.txt
Normal file
@ -0,0 +1,8 @@
pydantic>=2.0.0
Pillow>=9.0.0
datasets>=2.0.0
numpy>=1.21.0
pandas>=1.3.0
requests>=2.26.0
aiohttp>=3.8.0
python-dotenv>=0.19.0
416
environments/hack0/ufc_env/templates/predictor.html
Normal file
@ -0,0 +1,416 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Fight Predictor</title>
    <script src="https://cdn.jsdelivr.net/npm/canvas-confetti@1.6.0/dist/confetti.browser.min.js"></script>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #1a1a1a;
            color: #ffffff;
            margin: 0;
            padding: 20px;
        }

        .container {
            max-width: 1200px;
            margin: 0 auto;
        }

        h1 {
            text-align: center;
            margin-bottom: 30px;
        }

        .fighters-container {
            display: flex;
            justify-content: space-around;
            margin: 20px 0;
            position: relative;
        }

        .fighter {
            width: 300px;
            text-align: center;
            transition: all 0.5s ease;
        }

        .fighter img {
            width: 100%;
            height: auto;
            border: 3px solid;
            border-radius: 10px;
            transition: all 0.5s ease;
        }

        .red-corner img {
            border-color: #ff0000;
        }

        .blue-corner img {
            border-color: #0000ff;
        }

        .winner {
            transform: scale(1.2);
            z-index: 2;
            position: relative;
        }

        .winner::after {
            content: "WINNER!";
            position: absolute;
            top: -30px;
            left: 50%;
            transform: translateX(-50%);
            color: #ffd700;
            font-size: 24px;
            font-weight: bold;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
        }

        .loser {
            animation: explode 1s forwards;
        }

        @keyframes explode {
            0% {
                transform: scale(1);
                opacity: 1;
            }
            50% {
                transform: scale(1.5);
                opacity: 0.5;
            }
            100% {
                transform: scale(0);
                opacity: 0;
            }
        }

        .commentary {
            background-color: #2a2a2a;
            padding: 20px;
            border-radius: 10px;
            margin: 20px 0;
            min-height: 200px;
            white-space: pre-wrap;
            display: none;
        }

        .typing {
            overflow: hidden;
            border-right: 2px solid #fff;
            white-space: nowrap;
            animation: typing 1s steps(40, end),
                       blink-caret 0.75s step-end infinite;
        }

        @keyframes typing {
            from { width: 0 }
            to { width: 100% }
        }

        @keyframes blink-caret {
            from, to { border-color: transparent }
            50% { border-color: #fff }
        }

        .upload-form {
            display: flex;
            justify-content: space-around;
            margin: 20px 0;
        }

        .upload-section {
            text-align: center;
        }

        input[type="file"] {
            display: none;
        }

        .upload-btn {
            background-color: #333;
            color: white;
            padding: 10px 20px;
            border-radius: 5px;
            cursor: pointer;
            transition: background-color 0.3s;
        }

        .upload-btn:hover {
            background-color: #444;
        }

        .predict-btn {
            background-color: #ff0000;
            color: white;
            padding: 15px 30px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 18px;
            margin: 20px auto;
            display: block;
            transition: background-color 0.3s;
        }

        .predict-btn:hover {
            background-color: #cc0000;
        }

        .loading {
            display: none;
            text-align: center;
            margin: 20px 0;
            font-size: 24px;
        }

        .loading::before {
            content: "⚡";
            display: inline-block;
            animation: loading 0.5s infinite;
            margin-right: 10px;
        }

        .loading::after {
            content: "Analyzing Fight...";
            display: inline-block;
            animation: pulse 1s infinite;
        }

        @keyframes loading {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }

        @keyframes pulse {
            0% { opacity: 0.5; }
            50% { opacity: 1; }
            100% { opacity: 0.5; }
        }

        .new-fight-btn {
            background-color: #4CAF50;
            color: white;
            padding: 15px 30px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 18px;
            margin: 20px auto;
            display: none;
            transition: background-color 0.3s;
        }

        .new-fight-btn:hover {
            background-color: #45a049;
        }

        .buttons-container {
            display: flex;
            justify-content: center;
            gap: 20px;
            margin: 20px 0;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🤼‍♂️ Fight Predictor 🤼‍♂️</h1>

        <!-- Hidden iframe for airhorn -->
        <iframe id="airhorn" width="110" height="200" src="https://www.myinstants.com/instant/dj-airhorn/embed/" frameborder="0" scrolling="no" style="display: none;"></iframe>

        <form id="predictForm" class="upload-form">
            <div class="upload-section">
                <label class="upload-btn">
                    Upload Red Corner Fighter
                    <input type="file" name="red_fighter" accept="image/*" required>
                </label>
                <div class="fighter red-corner">
                    <img id="redPreview" src="" alt="Red Corner Fighter" style="display: none;">
                </div>
            </div>

            <div class="upload-section">
                <label class="upload-btn">
                    Upload Blue Corner Fighter
                    <input type="file" name="blue_fighter" accept="image/*" required>
                </label>
                <div class="fighter blue-corner">
                    <img id="bluePreview" src="" alt="Blue Corner Fighter" style="display: none;">
                </div>
            </div>
        </form>

        <div class="buttons-container">
            <button class="predict-btn" onclick="predictFight()">Predict Fight!</button>
            <button class="new-fight-btn" onclick="resetFight()">New Fight</button>
        </div>

        <div class="loading" id="loading"></div>
        <div class="commentary" id="commentary"></div>
    </div>

    <script>
        function previewImage(input, previewId) {
            const preview = document.getElementById(previewId);
            if (input.files && input.files[0]) {
                const reader = new FileReader();
                reader.onload = function(e) {
                    preview.src = e.target.result;
                    preview.style.display = 'block';
                }
                reader.readAsDataURL(input.files[0]);
            }
        }

        document.querySelector('input[name="red_fighter"]').addEventListener('change', function() {
            previewImage(this, 'redPreview');
        });

        document.querySelector('input[name="blue_fighter"]').addEventListener('change', function() {
            previewImage(this, 'bluePreview');
        });

        function playAirhorn() {
            const iframe = document.getElementById('airhorn');
            iframe.contentWindow.postMessage('play', '*');
        }

        function triggerConfetti() {
            const duration = 3 * 1000;
            const animationEnd = Date.now() + duration;
            const defaults = { startVelocity: 30, spread: 360, ticks: 60, zIndex: 0 };

            function randomInRange(min, max) {
                return Math.random() * (max - min) + min;
            }

            const interval = setInterval(function() {
                const timeLeft = animationEnd - Date.now();

                if (timeLeft <= 0) {
                    return clearInterval(interval);
                }

                const particleCount = 50 * (timeLeft / duration);

                // since particles fall down, start a bit higher than random
                confetti({
                    ...defaults,
                    particleCount,
                    origin: { x: randomInRange(0.1, 0.3), y: Math.random() - 0.2 }
                });
                confetti({
                    ...defaults,
                    particleCount,
                    origin: { x: randomInRange(0.7, 0.9), y: Math.random() - 0.2 }
                });
            }, 250);
        }

        function typeText(element, text, speed = 30) {
            let i = 0;
            element.textContent = '';
            element.style.display = 'block';

            function type() {
                if (i < text.length) {
                    // textContent keeps model output from being parsed as HTML
                    element.textContent += text.charAt(i);
                    i++;
                    setTimeout(type, speed);
                } else {
                    // After typing is complete, check for winner
                    const winnerMatch = text.match(/\\boxed\{(Red|Blue)\}/);
                    if (winnerMatch) {
                        const winner = winnerMatch[1];
                        const redFighter = document.querySelector('.red-corner');
                        const blueFighter = document.querySelector('.blue-corner');

                        if (winner === 'Red') {
                            redFighter.classList.add('winner');
                            blueFighter.classList.add('loser');
                            // Trigger effects from red corner
                            triggerConfetti();
                            playAirhorn();
                        } else {
                            blueFighter.classList.add('winner');
                            redFighter.classList.add('loser');
                            // Trigger effects from blue corner
                            triggerConfetti();
                            playAirhorn();
                        }
                        // Show new fight button after prediction is complete
                        document.querySelector('.new-fight-btn').style.display = 'block';
                    }
                }
            }

            type();
        }

        function resetFight() {
            // Reset images
            document.getElementById('redPreview').style.display = 'none';
            document.getElementById('redPreview').src = '';
            document.getElementById('bluePreview').style.display = 'none';
            document.getElementById('bluePreview').src = '';

            // Reset form
            document.getElementById('predictForm').reset();

            // Reset classes
            document.querySelector('.red-corner').classList.remove('winner', 'loser');
            document.querySelector('.blue-corner').classList.remove('winner', 'loser');

            // Clear and hide commentary
            const commentary = document.getElementById('commentary');
            commentary.textContent = '';
            commentary.style.display = 'none';

            // Hide new fight button
            document.querySelector('.new-fight-btn').style.display = 'none';
        }

        async function predictFight() {
            const form = document.getElementById('predictForm');
            const formData = new FormData(form);

            const loading = document.getElementById('loading');
            const commentary = document.getElementById('commentary');
            const predictBtn = document.querySelector('.predict-btn');

            loading.style.display = 'block';
            commentary.style.display = 'none';
            predictBtn.disabled = true;

            try {
                const response = await fetch('/predict', {
                    method: 'POST',
                    body: formData
                });

                const data = await response.json();

                if (data.success) {
                    typeText(commentary, data.prediction);
                } else {
                    commentary.style.display = 'block';
                    commentary.textContent = `Error: ${data.error}`;
                }
            } catch (error) {
                commentary.style.display = 'block';
                commentary.textContent = `Error: ${error.message}`;
            } finally {
                loading.style.display = 'none';
                predictBtn.disabled = false;
            }
        }
    </script>
</body>
</html>
326
environments/hack0/ufc_env/ufc_image_env.py
Normal file
@ -0,0 +1,326 @@
import os
import random
import re
import sys
import traceback
import csv
from typing import List, Optional, Tuple, Any, Dict
import base64
from PIL import Image
import io
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


def name_to_slug(name: str) -> str:
    """Convert a fighter name to the filename slug produced by make_url() in get_images.py."""
    slug = re.sub(r"[^a-z\s-]", "", name.lower().strip())  # Keep letters, spaces, hyphens
    return re.sub(r"\s+", "-", slug)  # Spaces -> hyphens


class UFCImageEnvConfig(BaseEnvConfig):
    """Configuration for the UFC Image Environment"""
    fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
    fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
    image_folder: str = Field(os.path.join(os.path.dirname(__file__), "fighter_images"), description="Path to fighter images folder")
    max_steps: int = Field(1, description="Only one step per fight prediction")
    temperature: float = Field(0.7, description="Temperature for generation diversity")
    top_p: float = Field(0.95, description="Top p for nucleus sampling")


class UFCImageEnv(BaseEnv):
    """UFC Fight Prediction Environment using only fighter images"""
    name = "ufc_image_predictor"
    env_config_cls = UFCImageEnvConfig

    def __init__(self, config: UFCImageEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
        super().__init__(config, server_configs, slurm, testing)
        self.fighter_stats = {}
        self.fight_data = []
        self.current_index = 0
        self.inference_server = self.server.servers[0]  # Get first server as inference server

    async def setup(self):
        """Load the fighter stats and fight data"""
        try:
            print("Loading fighter stats from:", self.config.fighter_stats_path)
            with open(self.config.fighter_stats_path, encoding="utf-8") as f:
                reader = csv.DictReader(f)
                self.fighter_stats = {row["name"]: row for row in reader}
            print(f"Loaded stats for {len(self.fighter_stats)} fighters")

            print("Loading fight data from:", self.config.fight_data_path)
            with open(self.config.fight_data_path, encoding="utf-8") as f:
                reader = csv.DictReader(f)
                self.fight_data = list(reader)
            print(f"Loaded {len(self.fight_data)} fights")

            # Filter out fights where either fighter's image is missing
            filtered_fights = []
            missing_images = set()  # Track unique missing images
            for fight in self.fight_data:
                r_fighter = fight["r_fighter"]
                b_fighter = fight["b_fighter"]

                # Convert names to image filename format
                r_slug = name_to_slug(r_fighter)
                b_slug = name_to_slug(b_fighter)

                r_image_path = os.path.join(self.config.image_folder, f"{r_slug}.jpg")
                b_image_path = os.path.join(self.config.image_folder, f"{b_slug}.jpg")

                if os.path.exists(r_image_path) and os.path.exists(b_image_path):
                    filtered_fights.append(fight)
                else:
                    if not os.path.exists(r_image_path):
                        missing_images.add(r_fighter)
                    if not os.path.exists(b_image_path):
                        missing_images.add(b_fighter)

            if missing_images:
                print(f"\nMissing images for {len(missing_images)} fighters. These fights will be skipped.")

            self.fight_data = filtered_fights
            print(f"Filtered to {len(self.fight_data)} fights with complete image sets")

        except Exception as e:
            print(f"Error loading data: {e}")
            traceback.print_exc()
            sys.exit(1)

    def get_fighter_image(self, fighter_name):
        """Convert fighter name to image path and return base64 encoded image"""
        try:
            # Convert name to slug format
            slug = name_to_slug(fighter_name)

            image_path = os.path.join(self.config.image_folder, f"{slug}.jpg")
            if not os.path.exists(image_path):
                return None

            # Convert image to base64
            with Image.open(image_path) as img:
                # Convert RGBA to RGB if necessary
                if img.mode == 'RGBA':
                    img = img.convert('RGB')
                buf = io.BytesIO()
                img.save(buf, format="JPEG")
                image_bytes = buf.getvalue()
                return base64.b64encode(image_bytes).decode("utf-8")
        except Exception as e:
            print(f"Error getting image for {fighter_name}: {e}")
            return None

    async def get_next_item(self) -> Optional[Item]:
        """Get the next fight from the dataset"""
        try:
            if self.current_index >= len(self.fight_data):
                return None
            fight = self.fight_data[self.current_index]
            self.current_index += 1

            r_fighter = fight["r_fighter"]
            b_fighter = fight["b_fighter"]

            # Get base64 encoded images
            r_image = self.get_fighter_image(r_fighter)
            b_image = self.get_fighter_image(b_fighter)

            if not r_image or not b_image:
                print(f"Skipping fight {self.current_index} due to missing images")
                return None

            # Format the prompt with images
            prompt_text = (
                "🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment! "
                "Let's break down this matchup that's got everyone talking!\n\n"
                "In the red corner, we have: (YOUR FIRST IMAGE)\n"
                "And in the blue corner: (YOUR SECOND IMAGE)\n\n"
                "Now, act as your favorite fight commentator. I want you to "
                "create a fight commentary of what's happening in the fight live.\n"
                "Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
                "Throw in some classic commentator phrases, maybe an 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
                "End your masterpiece with your prediction in this exact format:\n"
                "[S1] Hello, I'm your host [S2] And so am I (name) [S1] Wow. Amazing. (laughs) [S2] Let's get started! (coughs)\n\n"
                "The winner should always be announced with "
                "\\boxed{Red} or \\boxed{Blue}. "
                "Otherwise you will receive a score of -1.0."
            )

            # Create multimodal prompt with images
            prompt = tuple([
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{r_image}"}
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b_image}"}
                        }
                    ]
                }
            ])

            winner = fight.get("winner", "")  # Red or Blue
            ground_truth = f"Answer: {winner}" if winner else ""

            return (prompt, ground_truth, None)

        except Exception as e:
            print(f"Error in get_next_item: {e}")
            traceback.print_exc()
            return None

    async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
        to_score = []
        to_backlog = []

        system_msg = {
            "role": "system",
            "content": (
                "You are an expert MMA analyst. You will be given two UFC fighters' images. "
                "Your task is to predict the winner of the fight based on their appearance and physique.\n\n"
                "IMPORTANT: You MUST format your response in exactly two parts:\n"
                "1. First, analyze the fighters' appearances and create a fight commentary\n"
                "2. Then on a new line, give ONLY your final prediction in this exact format:\n"
                "\\boxed{Red} or \\boxed{Blue}\n\n"
                "For example:\n"
                "After analyzing the fighters' appearances... [your analysis here]\n"
                "\\boxed{Red}\n\n"
                "If you do not end your response with the \\boxed{} format containing either 'Red' or 'Blue', you will receive a score of -1.0."
            )
        }

        user_msg = {
            "role": "user",
            "content": dict(item[0][0])["content"]
        }

        messages = [system_msg, user_msg]

        try:
            chat_completions = await self.inference_server.chat_completion(
                messages=messages,
                n=self.config.group_size,
                max_tokens=2048,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                timeout=60,
            )
            for choice in chat_completions.choices:
                assistant_msg = {"role": "assistant", "content": choice.message.content}
                history = [
                    {"role": "system", "content": system_msg["content"]},
                    {"role": "user", "content": user_msg["content"]},
                    {"role": "assistant", "content": choice.message.content}
                ]
                to_score.append((history, item[1], None))
        except Exception as e:
            print(f"Error in collect_trajectories: {e}")
            traceback.print_exc()
            to_backlog.append(item)

        if not to_score:
            return None, to_backlog

        scored_data = await self.score(to_score)
        return scored_data, to_backlog

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        if not rollout_group_data:
            return None

        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["advantages"] = None
        scores["ref_logprobs"] = None
        scores["messages"] = None
        scores["group_overrides"] = {"group_size": self.config.group_size}
        scores["overrides"] = None
        scores["ground_truths"] = []

        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out = tokenize_for_trainer(self.tokenizer, item[0])
            tokens = out["tokens"]
            masks = out["masks"]

            try:
                # Extract prediction and ground truth
                reply = item[0][-1]["content"]
                ground_truth = item[1].strip().lower()

                # Extract color from ground truth (format: "answer: color")
|
||||
ground_truth_color = ground_truth.replace("answer:", "").strip()
|
||||
|
||||
# Extract color from \boxed{color} format
|
||||
import re
|
||||
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
|
||||
if boxed_match:
|
||||
prediction = boxed_match.group(1).strip().lower()
|
||||
# Compare just the colors
|
||||
reward = 1.0 if prediction == ground_truth_color else -1.0
|
||||
else:
|
||||
# No boxed answer found
|
||||
reward = -1.0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error scoring response: {e}")
|
||||
reward = -1.0
|
||||
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||
|
||||
if len([i for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(reward)
|
||||
scores["ground_truths"].append(ground_truth)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
if not scores["tokens"]:
|
||||
return None
|
||||
|
||||
return scores
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""No-op evaluation"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
|
||||
"""Initialize configuration for the environment"""
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable is not set!")
|
||||
sys.exit(1)
|
||||
|
||||
config = UFCImageEnvConfig(
|
||||
wandb_name="ufc_image",
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2,
|
||||
use_wandb=False,
|
||||
max_num_workers=2,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=1,
|
||||
steps_per_eval=10,
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4o",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=1,
|
||||
),
|
||||
]
|
||||
|
||||
return config, server_configs
|
||||
|
||||
if __name__ == "__main__":
|
||||
UFCImageEnv.cli()
|
||||
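The scoring rule in `score` above can be isolated into a small pure function, which makes it easier to test without a tokenizer or an inference server. The following is a minimal sketch, not part of the commit: the names `BOXED_RE` and `score_reply` are hypothetical, and it mirrors the behavior of the original (first `\boxed{...}` match, case-insensitive comparison, -1.0 when no boxed answer is present).

```python
import re

# Matches the first \boxed{...} in a reply and captures its content.
BOXED_RE = re.compile(r"\\boxed\{([^}]+)\}")


def score_reply(reply: str, ground_truth: str) -> float:
    """Return 1.0 if the boxed prediction matches the ground truth, else -1.0.

    `ground_truth` is expected in the "Answer: <value>" form produced by
    get_next_item; the "Answer:" prefix is stripped before comparing.
    """
    expected = ground_truth.strip().lower().replace("answer:", "").strip()
    match = BOXED_RE.search(reply)
    if match is None:
        # No boxed answer found.
        return -1.0
    prediction = match.group(1).strip().lower()
    return 1.0 if prediction == expected else -1.0


if __name__ == "__main__":
    print(score_reply(r"What a fight! \boxed{Red}", "Answer: Red"))
    print(score_reply(r"What a fight! \boxed{Blue}", "Answer: Red"))
    print(score_reply("No prediction given.", "Answer: Red"))
```

Note that when a fight row has no recorded winner, the ground truth is an empty string, so any non-blank boxed answer scores -1.0 under this rule.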
104
environments/hack0/ufc_env/ufc_predictor_ui.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from flask import Flask, render_template, request, jsonify
|
||||
from PIL import Image
|
||||
import openai
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
|
||||
|
||||
def process_image(image_file):
|
||||
"""Convert uploaded image to base64"""
|
||||
img = Image.open(image_file)
|
||||
# Convert any non-RGB mode (RGBA, P, LA, ...) to RGB; Pillow cannot save those modes as JPEG
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
@app.route('/')
|
||||
def home():
|
||||
return render_template('predictor.html')
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
try:
|
||||
# Get uploaded images
|
||||
red_fighter = request.files.get('red_fighter')
|
||||
blue_fighter = request.files.get('blue_fighter')
|
||||
|
||||
if not red_fighter or not blue_fighter:
|
||||
return jsonify({'error': 'Please upload both fighter images'}), 400
|
||||
|
||||
# Process images to base64
|
||||
red_image = process_image(red_fighter)
|
||||
blue_image = process_image(blue_fighter)
|
||||
|
||||
# Create the prompt
|
||||
prompt_text = (
|
||||
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment! "
|
||||
"Let's break down this matchup that's got everyone talking!\n\n"
|
||||
"In the red corner, we have: (YOUR FIRST IMAGE):\n"
|
||||
"And in the blue corner: (YOUR SECOND IMAGE):\n\n"
|
||||
"Now, as your favorite fight commentator, I want you to:\n"
|
||||
"create a fight commentary of what's happening in the fight live\n"
|
||||
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
|
||||
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
|
||||
"End your masterpiece with your prediction in this exact format:\n"
|
||||
"\\boxed{Red} or \\boxed{Blue}\n\n"
|
||||
"PLEASE FORMAT THE COMMENTARY IN THE EXACT FORMAT AS THE EXAMPLE BELOW:\n"
|
||||
"[S1]Hello im your host [S2] And so am i (name) [S1] Wow. Amazing. (laughs) [S2] Lets get started! (coughs) ( add lots of coughs and laughs)\n\n"
|
||||
)
|
||||
|
||||
# Create the messages for the API call
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt_text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{red_image}"}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{blue_image}"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Make the API call
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
temperature=0.7,
|
||||
top_p=0.95
|
||||
)
|
||||
|
||||
# Extract the prediction
|
||||
prediction = response.choices[0].message.content
|
||||
|
||||
return jsonify({
|
||||
'prediction': prediction,
|
||||
'success': True
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'error': str(e),
|
||||
'success': False
|
||||
}), 500
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
||||
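The multimodal request built in `predict` above follows a fixed shape: one user message whose content is a text part followed by two `image_url` parts carrying base64 data URLs. A minimal standard-library sketch of that construction is shown below; the helper names `to_data_url` and `build_messages` are hypothetical and not part of the commit, and the inputs are assumed to be already JPEG-encoded bytes.

```python
import base64
from typing import Any, Dict, List


def to_data_url(jpeg_bytes: bytes) -> str:
    """Encode JPEG bytes as a data URL suitable for an image_url content part."""
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def build_messages(prompt_text: str, red_jpeg: bytes, blue_jpeg: bytes) -> List[Dict[str, Any]]:
    """Build a single user message: prompt text, then red image, then blue image."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_text},
                {"type": "image_url", "image_url": {"url": to_data_url(red_jpeg)}},
                {"type": "image_url", "image_url": {"url": to_data_url(blue_jpeg)}},
            ],
        }
    ]


if __name__ == "__main__":
    # Placeholder bytes stand in for real JPEG data in this sketch.
    messages = build_messages("Call the fight!", b"\xff\xd8red", b"\xff\xd8blue")
    print(messages[0]["content"][1]["image_url"]["url"])
```

The order of the image parts matters here, because the prompt refers to the first image as the red corner and the second as the blue corner.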
233
environments/hack0/ufc_env/ufc_server.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
import csv
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class UFCEnvConfig(BaseEnvConfig):
|
||||
"""Configuration for the UFC Environment"""
|
||||
fighter_stats_path: str = Field(os.path.join(os.path.dirname(__file__), "fighter_stats.csv"), description="Path to fighter stats CSV")
|
||||
fight_data_path: str = Field(os.path.join(os.path.dirname(__file__), "large_dataset.csv"), description="Path to large fight dataset CSV")
|
||||
max_steps: int = Field(1, description="Only one step per fight prediction")
|
||||
temperature: float = Field(0.7, description="Temperature for generation diversity")
|
||||
top_p: float = Field(0.95, description="Top p for nucleus sampling")
|
||||
|
||||
|
||||
class UFCEnv(BaseEnv):
|
||||
"""UFC Fight Prediction Environment"""
|
||||
name = "ufc_predictor"
|
||||
env_config_cls = UFCEnvConfig
|
||||
|
||||
def __init__(self, config: UFCEnvConfig, server_configs: List[OpenaiConfig], slurm=True, testing=False):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.fighter_stats = {}
|
||||
self.fight_data = []
|
||||
self.current_index = 0
|
||||
self.inference_server = self.server.servers[0] # Get first server as inference server
|
||||
|
||||
async def setup(self):
|
||||
"""Load the fighter stats and fight data"""
|
||||
try:
|
||||
print("Loading fighter stats from:", self.config.fighter_stats_path)
|
||||
with open(self.config.fighter_stats_path, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.fighter_stats = {row["name"]: row for row in reader}
|
||||
print(f"Loaded stats for {len(self.fighter_stats)} fighters")
|
||||
|
||||
print("Loading fight data from:", self.config.fight_data_path)
|
||||
with open(self.config.fight_data_path, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.fight_data = list(reader)
|
||||
print(f"Loaded {len(self.fight_data)} fights")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading data: {e}")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
async def get_next_item(self) -> Optional[Item]:
|
||||
"""Get the next fight from the dataset"""
|
||||
try:
|
||||
if self.current_index >= len(self.fight_data):
|
||||
return None
|
||||
fight = self.fight_data[self.current_index]
|
||||
self.current_index += 1
|
||||
|
||||
r_fighter = fight["r_fighter"]
|
||||
b_fighter = fight["b_fighter"]
|
||||
r_stats = self.fighter_stats.get(r_fighter, {})
|
||||
b_stats = self.fighter_stats.get(b_fighter, {})
|
||||
|
||||
# Format the prompt
|
||||
def stats_str(name, stats):
|
||||
if not stats:
|
||||
return f"{name}: (No stats available)"
|
||||
return (
|
||||
f"Name: {name}\n"
|
||||
f"Wins: {stats.get('wins','?')} Losses: {stats.get('losses','?')} Age: {stats.get('age','?')}\n"
|
||||
f"Height: {stats.get('height','?')} cm Weight: {stats.get('weight','?')} kg Reach: {stats.get('reach','?')} cm Stance: {stats.get('stance','?')}\n"
|
||||
f"SLpM: {stats.get('SLpM','?')} Sig Str Acc: {stats.get('sig_str_acc','?')} SApM: {stats.get('SApM','?')} Str Def: {stats.get('str_def','?')}\n"
|
||||
f"TD Avg: {stats.get('td_avg','?')} TD Acc: {stats.get('td_acc','?')} TD Def: {stats.get('td_def','?')} Sub Avg: {stats.get('sub_avg','?')}\n"
|
||||
)
|
||||
|
||||
prompt_text = (
|
||||
"🎤 LADIES AND GENTLEMEN! Welcome to the most electrifying show in sports entertainment - the UFC Fight Prediction Show! "
|
||||
"Let's break down this matchup that's got everyone talking!\n\n"
|
||||
f"*Drumroll please* In the red corner, we have:\n{stats_str(r_fighter, r_stats)}\n\n"
|
||||
f"And in the blue corner:\n{stats_str(b_fighter, b_stats)}\n\n"
|
||||
"Now, as your favorite fight analyst who's definitely not just making this up as I go along, I want you to:\n"
|
||||
"1. Break down these fighters like you're explaining why your favorite TV show character would win in a fight\n"
|
||||
"2. Compare their styles\n"
|
||||
"3. Point out their advantages\n"
|
||||
"Give us your best fight commentary! Make it exciting, make it dramatic, make it sound like you're calling the fight live! "
|
||||
"Throw in some classic commentator phrases, maybe a 'OH MY GOODNESS!' or two, and definitely some dramatic pauses for effect.\n\n"
|
||||
"End your masterpiece with the winner's name in this exact format:\n"
|
||||
"\\boxed{fighter name}"
|
||||
)
|
||||
|
||||
prompt = tuple([
|
||||
frozenset({"role": "user", "content": prompt_text}.items())
|
||||
])
|
||||
|
||||
winner = fight.get("winner", "") # Red or Blue
|
||||
winner_name = r_fighter if winner == "Red" else b_fighter if winner == "Blue" else ""
|
||||
ground_truth = f"Answer: {winner_name}" if winner_name else ""
|
||||
|
||||
return (prompt, ground_truth, None)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in get_next_item: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def collect_trajectories(self, item: Item) -> Tuple[List[Tuple[GameHistory, str, Optional[str]]], List[Item]]:
|
||||
to_score = []
|
||||
to_backlog = []
|
||||
|
||||
system_msg = {
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an expert MMA analyst. You will be given two UFC fighters and their stats. "
|
||||
"Your task is to predict the winner of the fight based on their statistics.\n\n"
|
||||
"IMPORTANT: You MUST format your response in exactly two parts:\n"
|
||||
"1. First, analyze the fighters' stats and create a fight commentary\n"
|
||||
"2. Then on a new line, give ONLY your final prediction in this exact format:\n"
|
||||
"\\boxed{fighter name}\n\n"
|
||||
"For example:\n"
|
||||
"After analyzing stats... [your analysis here]\n"
|
||||
"\\boxed{John Smith}\n\n"
|
||||
"If you do not end your response with the \\boxed{} format, you will receive a score of -1.0."
|
||||
)
|
||||
}
|
||||
|
||||
user_msg = {
|
||||
"role": "user",
|
||||
"content": dict(item[0][0])["content"]
|
||||
}
|
||||
|
||||
messages = [system_msg, user_msg]
|
||||
|
||||
try:
|
||||
chat_completions = await self.inference_server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=2048, # Increased from 512 to allow for longer, more detailed fight commentaries
|
||||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
timeout=60,
|
||||
)
|
||||
for choice in chat_completions.choices:
|
||||
assistant_msg = {"role": "assistant", "content": choice.message.content}
|
||||
history = [
|
||||
{"role": "system", "content": system_msg["content"]},
|
||||
{"role": "user", "content": user_msg["content"]},
|
||||
{"role": "assistant", "content": choice.message.content}
|
||||
]
|
||||
to_score.append((history, item[1], None))
|
||||
except Exception as e:
|
||||
print(f"Error in collect_trajectories: {e}")
|
||||
traceback.print_exc()
|
||||
to_backlog.append(item)
|
||||
|
||||
if not to_score:
|
||||
return None, to_backlog
|
||||
|
||||
scored_data = await self.score(to_score)
|
||||
return scored_data, to_backlog
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
if not rollout_group_data:
|
||||
return None
|
||||
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = []
|
||||
scores["masks"] = []
|
||||
scores["scores"] = []
|
||||
scores["advantages"] = None
|
||||
scores["ref_logprobs"] = None
|
||||
scores["messages"] = None
|
||||
scores["group_overrides"] = {"group_size": self.config.group_size}
|
||||
scores["overrides"] = None
|
||||
scores["ground_truths"] = []
|
||||
|
||||
random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
out = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out["tokens"]
|
||||
masks = out["masks"]
|
||||
|
||||
try:
|
||||
# Extract prediction and ground truth
|
||||
reply = item[0][-1]["content"]
|
||||
ground_truth = item[1].strip().lower()
|
||||
|
||||
# Extract name from ground truth (format: "answer: name")
|
||||
ground_truth_name = ground_truth.replace("answer:", "").strip()
|
||||
|
||||
# Extract name from \boxed{name} format
|
||||
import re
|
||||
boxed_match = re.search(r"\\boxed{([^}]+)}", reply)
|
||||
if boxed_match:
|
||||
prediction = boxed_match.group(1).strip().lower()
|
||||
# Compare just the names
|
||||
reward = 1.0 if prediction == ground_truth_name else -1.0
|
||||
else:
|
||||
# No boxed answer found
|
||||
reward = -1.0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error scoring response: {e}")
|
||||
reward = -1.0
|
||||
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||
|
||||
if len([i for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(reward)
|
||||
scores["ground_truths"].append(ground_truth)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
if not scores["tokens"]:
|
||||
return None
|
||||
|
||||
return scores
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""No-op evaluation"""
|
||||
return
|
||||
|
||||
if __name__ == "__main__":
|
||||
UFCEnv.cli()
|
||||
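In `ufc_server.py` the prompt is stored as a tuple of `frozenset` items, which keeps the item hashable, and `collect_trajectories` recovers the text with `dict(item[0][0])["content"]`. The sketch below illustrates that round trip together with the mapping from the `winner` column ("Red" or "Blue") to a fighter name. The helper names `make_item` and `user_content` and the fighter "Jane Doe" are hypothetical, used only for illustration.

```python
from typing import Dict, Tuple


def make_item(prompt_text: str, fight: Dict[str, str]) -> Tuple[tuple, str, None]:
    """Build an (prompt, ground_truth, None) item the way get_next_item does."""
    # A frozenset of (key, value) pairs is hashable, unlike a dict.
    prompt = (frozenset({"role": "user", "content": prompt_text}.items()),)

    winner = fight.get("winner", "")  # "Red", "Blue", or anything else
    if winner == "Red":
        winner_name = fight["r_fighter"]
    elif winner == "Blue":
        winner_name = fight["b_fighter"]
    else:
        winner_name = ""
    ground_truth = f"Answer: {winner_name}" if winner_name else ""
    return (prompt, ground_truth, None)


def user_content(item: Tuple[tuple, str, None]) -> str:
    """Recover the user prompt text from an item, as collect_trajectories does."""
    return dict(item[0][0])["content"]


if __name__ == "__main__":
    fight = {"r_fighter": "John Smith", "b_fighter": "Jane Doe", "winner": "Blue"}
    item = make_item("Who wins?", fight)
    print(item[1])
    print(user_content(item))
```

Rows whose `winner` value is neither "Red" nor "Blue" produce an empty ground truth, which the scoring step then treats as unmatched.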