mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-28 17:29:30 +00:00)

MMMU, MMMU-Pro, MMBench, MMStar, AI2D, MMVP, OCRBench, MMVet, CountBench, POPE, HallusionBench, DynaMath, MMT-Bench, SEED-Bench2, BLINK, and VLMBlind evals

This commit is contained in:
parent 75de490849
commit 22884d2bf7

16 changed files with 2748 additions and 0 deletions
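Every environment in this commit doubles as a standalone script via its __main__ block. A minimal programmatic invocation might look like the sketch below; it mirrors the __main__ blocks in the diff, but eval_runner's remaining options (model name, endpoint, concurrency) live in eval_base, which is not part of this commit, so treat those as assumptions:

# Sketch: run the AI2D eval the same way its __main__ block does.
# Assumes the repository root is on PYTHONPATH and that eval_base
# resolves the model endpoint from its own configuration.
import asyncio

from environments.eval_environments.ai2d_environment import AI2D
from environments.eval_environments.eval_base import eval_runner

asyncio.run(eval_runner(AI2D, split="test", temperature=0.0, max_tokens=256))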
167
environments/eval_environments/ai2d_environment.py
Normal file
@@ -0,0 +1,167 @@
"""AI2D (AI2 Diagrams) evaluation environment."""

import ast
import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class AI2D(EvalBase):
    """AI2D evaluation - diagram understanding benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "test")
        use_mask = getattr(self, "use_mask", True)

        try:
            dataset = load_dataset("lmms-lab/ai2d", split=split)
            print(f"Loaded {len(dataset)} examples from AI2D ({split})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load AI2D: {e}")
            try:
                dataset = load_dataset("allenai/ai2_diagrams", split=split)
                print(f"Loaded {len(dataset)} examples from AI2D ({split})")
                return list(dataset)
            except Exception:
                raise ValueError(f"Could not load AI2D dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        choices = item.get("choices", [])
        if isinstance(choices, str):
            try:
                choices = ast.literal_eval(choices)
            except Exception:
                choices = []

        options = {}
        if choices:
            for i, choice in enumerate(choices):
                options[ascii_uppercase[i]] = choice
        else:
            for letter in ascii_uppercase[:6]:
                if letter in item and item[letter] is not None:
                    val = item[letter]
                    if isinstance(val, str) and val.strip():
                        options[letter] = val

        prompt = f"Question: {question}\n"
        if options:
            prompt += "Options:\n"
            for letter in sorted(options.keys()):
                prompt += f"{letter}. {options[letter]}\n"
            prompt += "\nPlease select the correct answer from the options above."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")

            choices = data_item.get("choices", [])
            if isinstance(choices, str):
                try:
                    choices = ast.literal_eval(choices)
                except Exception:
                    choices = []

            num_choices = len(choices) if choices else 4

            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                if str(answer).isdigit():
                    # AI2D answers are 0-indexed option positions
                    answer_letter = ascii_uppercase[int(answer)]
                else:
                    answer_letter = str(answer).upper()
                correct = extracted.upper() == answer_letter

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            AI2D,
            split="test",
            temperature=0.0,
            max_tokens=256,
        )
    )
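All of the multiple-choice environments in this commit delegate extraction to two helpers from eval_helpers that the diff does not include. Their real implementation is unknown; a minimal sketch of the contract the call sites rely on (a (letter, method_name) tuple, with None for a miss) might look like this, assuming an <answer>X</answer>-style tag:

import re
from typing import Optional, Set, Tuple

def extract_letter_from_answer_tag(response: str, valid_letters: Set[str]) -> Tuple[Optional[str], str]:
    # Hypothetical sketch only: the real helper in eval_helpers may
    # recognize other tag formats besides <answer>X</answer>.
    match = re.search(r"<answer>\s*([A-Za-z])\s*</answer>", response)
    if match and match.group(1).upper() in valid_letters:
        return match.group(1).upper(), "answer_tag"
    return None, "no_answer_tag"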
170
environments/eval_environments/blink_environment.py
Normal file
@@ -0,0 +1,170 @@
"""BLINK evaluation environment."""

import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class BLINK(EvalBase):
    """BLINK evaluation - visual perception benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "val")
        task = getattr(self, "task", "Counting")  # One of the BLINK task categories

        try:
            dataset = load_dataset("BLINK-Benchmark/BLINK", task, split=split)
            print(f"Loaded {len(dataset)} examples from BLINK ({split}, {task})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load BLINK: {e}")
            # Fall back to loading a fixed set of task categories
            tasks = ["Counting", "Spatial_Relation", "Object_Localization", "Visual_Similarity"]
            all_data = []
            for t in tasks:
                try:
                    ds = load_dataset("BLINK-Benchmark/BLINK", t, split=split)
                    for item in ds:
                        item["task"] = t
                        all_data.append(item)
                except Exception:
                    pass
            if all_data:
                print(f"Loaded {len(all_data)} examples from BLINK ({split})")
                return all_data
            raise ValueError(f"Could not load BLINK dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_images(self, item: dict) -> List[str]:
        """Get all images from the item (BLINK can have multiple images)."""
        images = []
        for i in range(1, 5):
            key = f"image_{i}"
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    images.append(self.encode_image(item[key]))

        if not images and "image" in item and item["image"] is not None:
            if isinstance(item["image"], Image.Image):
                images.append(self.encode_image(item["image"]))

        return images

    def build_messages(self, item: dict) -> List[dict]:
        images = self.get_images(item)
        question = item.get("question", "")

        options = {}
        for letter in ascii_uppercase[:6]:
            if letter in item and item[letter] is not None:
                val = item[letter]
                if isinstance(val, str) and val.strip():
                    options[letter] = val

        prompt = f"Question: {question}\n"
        if options:
            prompt += "Options:\n"
            for letter in sorted(options.keys()):
                prompt += f"{letter}. {options[letter]}\n"
            prompt += "\nPlease select the correct answer from the options above."

        content = []
        for img_b64 in images:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_b64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")

            num_choices = sum(
                1 for letter in ascii_uppercase[:6]
                if letter in data_item and data_item[letter] is not None
                and isinstance(data_item[letter], str) and data_item[letter].strip()
            )
            num_choices = max(num_choices, 4)

            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                correct = extracted.upper() == str(answer).upper()

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "category": data_item.get("category", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            BLINK,
            split="val",
            temperature=0.0,
            max_tokens=256,
        )
    )
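For a two-image BLINK item, build_messages returns a single user turn whose content interleaves the image parts before the text part; the resulting payload has this shape (base64 strings abbreviated):

messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0K..."}},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0K..."}},
        {"type": "text", "text": "Question: ...\nOptions:\nA. ...\nB. ...\n\nPlease select the correct answer from the options above."},
    ],
}]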
145
environments/eval_environments/countbench_environment.py
Normal file
@@ -0,0 +1,145 @@
"""CountBench evaluation environment."""

import asyncio
import base64
import io
import re
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class CountBench(EvalBase):
    """CountBench evaluation - object counting benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "train")  # CountBench only has a train split

        try:
            dataset = load_dataset("nielsr/countbench", split=split)
            print(f"Loaded {len(dataset)} examples from CountBench ({split})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load CountBench: {e}")
            try:
                # Try the train split explicitly
                dataset = load_dataset("nielsr/countbench", split="train")
                print(f"Loaded {len(dataset)} examples from CountBench (train)")
                return list(dataset)
            except Exception:
                try:
                    dataset = load_dataset("google-research/countbenchqa", split="train")
                    print(f"Loaded {len(dataset)} examples from CountBench (train)")
                    return list(dataset)
                except Exception:
                    raise ValueError(f"Could not load CountBench dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"{question}\n\nNote: Answer with a number directly, e.g., 3. Do not include any additional text."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_number(self, response: str) -> Optional[str]:
        """Extract the first number from the response."""
        numbers = re.findall(r"\b(\d+)\b", response)
        if numbers:
            return numbers[0]
        return None

    def score(self, prediction: str, answer: str) -> bool:
        """Score a counting answer - check whether the answer appears in the prediction."""
        answer_str = str(answer).strip()

        if answer_str in prediction:
            return True

        extracted = self.extract_number(prediction)
        if extracted and extracted == answer_str:
            return True

        try:
            pred_num = int(self.extract_number(prediction) or prediction.strip())
            ans_num = int(answer_str)
            return pred_num == ans_num
        except (ValueError, TypeError):
            pass

        return False

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", data_item.get("number", ""))

            correct = self.score(response, answer)
            extracted = self.extract_number(response)

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "answer": answer,
                "prediction": extracted or response[:50],
                "raw_response": response[:200],
                "correct": correct,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            CountBench,
            split="train",  # CountBench ships only a train split
            temperature=0.0,
            max_tokens=64,
        )
    )
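CountBench.extract_number keeps only the first standalone integer, so a chatty reply still scores; one detail worth knowing is that the \b(\d+)\b pattern splits decimals at the dot:

import re

print(re.findall(r"\b(\d+)\b", "I count 12 apples."))   # ['12']
print(re.findall(r"\b(\d+)\b", "Roughly 3.5 per box"))  # ['3', '5'] - decimal split at the dot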
246
environments/eval_environments/dynamath_environment.py
Normal file
@@ -0,0 +1,246 @@
"""DynaMath evaluation environment."""

import asyncio
import base64
import io
import json
import re
from string import ascii_uppercase
from typing import List, Optional, Tuple

import numpy as np
from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class DynaMath(EvalBase):
    """DynaMath evaluation - dynamic mathematical reasoning benchmark."""

    GUIDE = """
## Answer Instruction
Please provide an answer to the question outlined above. Your response should adhere to the following JSON format, which includes two keys: 'solution' and 'short answer'. The 'solution' key can contain detailed steps needed to solve the question, and the 'short answer' key should provide a concise response. {INST}

Example of expected JSON response format:

{{
    "solution": "[Detailed step-by-step explanation]",
    "short answer": "[Concise Answer]"
}}
"""

    def setup_data(self) -> list:
        # DynaMath_Sample uses variant splits: sample_variant1, sample_variant2, etc.
        split = getattr(self, "split", "sample_variant1")

        try:
            # DynaMath_Sample is the publicly available dataset
            dataset = load_dataset("DynaMath/DynaMath_Sample", split=split)
            print(f"Loaded {len(dataset)} examples from DynaMath ({split})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load DynaMath: {e}")
            try:
                # Try sample_variant1 explicitly
                dataset = load_dataset("DynaMath/DynaMath_Sample", split="sample_variant1")
                print(f"Loaded {len(dataset)} examples from DynaMath (sample_variant1)")
                return list(dataset)
            except Exception:
                try:
                    dataset = load_dataset("lmms-lab/DynaMath", split="test")
                    print(f"Loaded {len(dataset)} examples from DynaMath (test)")
                    return list(dataset)
                except Exception:
                    raise ValueError(f"Could not load DynaMath dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")
        answer_type = item.get("answer_type", "free_form")

        use_json_format = getattr(self, "use_json_format", True)

        if use_json_format:
            if answer_type == "multiple choice":
                inst = "Provide the corresponding choice option in the 'short answer' key, such as 'A', 'B', 'C', or 'D'."
            elif answer_type == "float":
                inst = "Format the answer as a three-digit floating-point number and provide it in the 'short answer' key."
            else:
                inst = "Float numbers in the answer should be formatted as three-digit floating-point numbers."

            prompt = f"## Question\n{question}" + self.GUIDE.format(INST=inst)
        else:
            prompt = question

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def preprocess_response(self, response: str) -> str:
        """Preprocess the response to isolate its JSON body."""
        response = str(response)
        if 0 <= response.find("{") < response.rfind("}"):
            response = response[response.find("{"): response.rfind("}") + 1]
        # Unescape newlines before stripping stray backslashes; the reverse
        # order would delete the backslash in "\\n" first and leave "n" behind.
        response = response.replace("\\n", "\n").replace("\\", "")
        return response

    def transfer_pi(self, value: str) -> float:
        """Convert a value containing the pi symbol to a numeric value."""
        if "\u03c0" in value:
            parts = value.split("\u03c0")
            coeff = float(parts[0]) if parts[0].strip() else 1.0  # bare "π" means 1·π
            return coeff * np.pi
        return float(value)

    def parse_answer(self, answer: str, answer_type: str) -> Tuple[bool, Optional[str]]:
        """Parse the answer based on its type."""
        if answer_type == "float":
            if answer.isdigit():
                return True, str(float(answer))
            parts = answer.split(" ")
            answer = parts[0]
            try:
                result = self.transfer_pi(answer)
                return True, str(result)
            except Exception:
                return False, None

        elif answer_type == "multiple choice":
            if len(answer) == 1 and answer.upper() in ascii_uppercase[:5]:
                return True, answer.upper()
            # Check whether any option letter appears
            for ch in ascii_uppercase[:5]:
                if ch in answer.upper():
                    return True, ch
            return False, None

        else:
            return True, answer

    def extract_answer(self, response: str, answer_type: str) -> Tuple[bool, Optional[str]]:
        """Extract the answer from the response."""
        processed = self.preprocess_response(response)

        try:
            dj = json.loads(processed, strict=False)
            short_answer = dj.get("short answer")
            if short_answer is not None:
                return self.parse_answer(str(short_answer), answer_type)
        except Exception:
            pass

        if answer_type == "multiple choice":
            for ch in ascii_uppercase[:5]:
                if response.strip().upper().startswith(ch):
                    return True, ch
            for ch in ascii_uppercase[:5]:
                if ch in response.upper()[:20]:
                    return True, ch
        elif answer_type == "float":
            numbers = re.findall(r"-?\d+\.?\d*", response)
            if numbers:
                try:
                    return True, str(float(numbers[0]))
                except ValueError:
                    pass

        return False, None

    def score_answer(
        self, extracted: Optional[str], answer: str, answer_type: str, parsed: bool
    ) -> bool:
        """Score the extracted answer against the ground truth."""
        if not parsed or extracted is None:
            return False

        if answer_type == "float":
            try:
                pred_val = float(extracted)
                ans_val = float(answer)
                return abs(pred_val - ans_val) <= 0.001
            except (ValueError, TypeError):
                return False

        elif answer_type == "multiple choice":
            return extracted.upper() == str(answer).upper()

        else:
            # Free form: substring match in either direction
            return extracted.lower() in answer.lower() or answer.lower() in extracted.lower()

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("ground_truth", data_item.get("answer", ""))
            answer_type = data_item.get("answer_type", "free_form")

            parsed, extracted = self.extract_answer(response, answer_type)
            correct = self.score_answer(extracted, answer, answer_type, parsed)

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "subject": data_item.get("subject", ""),
                "knowledge_level": data_item.get("knowledge_level", ""),
                "answer_type": answer_type,
                "answer": answer,
                "prediction": extracted,
                "parsed": parsed,
                "raw_response": response[:500],
                "correct": correct,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            DynaMath,
            split="sample_variant1",  # DynaMath_Sample exposes variant splits, not "test"
            use_json_format=True,
            temperature=0.0,
            max_tokens=1024,
        )
    )
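preprocess_response trims a chatty completion down to its outermost brace pair before json.loads, which is what makes the JSON contract in GUIDE forgiving of surrounding prose:

raw = 'Sure, here you go:\n{"solution": "2 + 2 = 4", "short answer": "4.000"}\nHope that helps!'
start, end = raw.find("{"), raw.rfind("}")
print(raw[start:end + 1])  # {"solution": "2 + 2 = 4", "short answer": "4.000"}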
153
environments/eval_environments/hallusionbench_environment.py
Normal file
@@ -0,0 +1,153 @@
"""HallusionBench evaluation environment."""

import asyncio
import base64
import io
import re
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class HallusionBench(EvalBase):
    """HallusionBench evaluation - visual hallucination benchmark."""

    def setup_data(self) -> list:
        # HallusionBench has 'image' and 'non_image' splits
        split = getattr(self, "split", "image")

        try:
            dataset = load_dataset("lmms-lab/HallusionBench", split=split)
            print(f"Loaded {len(dataset)} examples from HallusionBench ({split})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load HallusionBench: {e}")
            # Try combining both splits
            all_data = []
            for s in ["image", "non_image"]:
                try:
                    ds = load_dataset("lmms-lab/HallusionBench", split=s)
                    all_data.extend(list(ds))
                except Exception:
                    pass
            if all_data:
                print(f"Loaded {len(all_data)} examples from HallusionBench (combined)")
                return all_data
            raise ValueError(f"Could not load HallusionBench dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"{question}\n\nPlease answer yes or no."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_yorn(self, response: str) -> str:
        """Extract Yes/No from the response."""
        response_lower = response.lower().strip()

        if response_lower.startswith("yes"):
            return "Yes"
        if response_lower.startswith("no"):
            return "No"

        yes_patterns = [r"\byes\b", r"\btrue\b", r"\bcorrect\b"]
        no_patterns = [r"\bno\b", r"\bfalse\b", r"\bincorrect\b"]

        for pattern in yes_patterns:
            if re.search(pattern, response_lower):
                return "Yes"

        for pattern in no_patterns:
            if re.search(pattern, response_lower):
                return "No"

        return "Unknown"

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", data_item.get("gt_answer", ""))
            extracted = self.extract_yorn(response)

            answer_norm = str(answer).strip().lower()
            if answer_norm in ["yes", "true", "1"]:
                answer_norm = "Yes"
            elif answer_norm in ["no", "false", "0"]:
                answer_norm = "No"
            else:
                answer_norm = str(answer).strip()

            correct = extracted == answer_norm

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "category": data_item.get("category", data_item.get("subcategory", "")),
                "answer": answer_norm,
                "prediction": extracted,
                "raw_response": response[:200],
                "correct": correct,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            HallusionBench,
            split="image",  # dataset splits are 'image' and 'non_image'
            temperature=0.0,
            max_tokens=64,
        )
    )
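extract_yorn prefers a prefix match and only then falls back to word-boundary patterns, so hedged phrasing still resolves to a verdict:

import re

resp = "Based on the image, the answer is no."
print(resp.lower().startswith(("yes", "no")))    # False - the prefix check misses
print(bool(re.search(r"\bno\b", resp.lower())))  # True  - the fallback resolves to "No"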
164
environments/eval_environments/mmbench_environment.py
Normal file
@@ -0,0 +1,164 @@
"""MMBench evaluation environment."""

import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class MMBench(EvalBase):
    """MMBench evaluation - comprehensive multimodal benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "dev")
        lang = getattr(self, "lang", "en")  # en, cn, cc
        version = getattr(self, "version", "v1.1")  # v1.0 or v1.1

        try:
            dataset = load_dataset("lmms-lab/MMBench", lang, split=split)
            print(f"Loaded {len(dataset)} examples from MMBench ({split}, {lang})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load from lmms-lab: {e}")
            try:
                dataset = load_dataset("lmms-lab/MMBench_EN", split=split)
                print(f"Loaded {len(dataset)} examples from MMBench ({split})")
                return list(dataset)
            except Exception:
                raise ValueError(f"Could not load MMBench dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")
        hint = item.get("hint", "")

        options = {}
        for letter in ascii_uppercase:
            if letter in item and item[letter] is not None:
                val = item[letter]
                if isinstance(val, str) and val.strip():
                    options[letter] = val
                elif not isinstance(val, float):
                    options[letter] = str(val)

        prompt = ""
        if hint and str(hint).strip() and str(hint).lower() != "nan":
            prompt += f"Hint: {hint}\n"
        prompt += f"Question: {question}\n"

        if options:
            prompt += "Options:\n"
            for letter in sorted(options.keys()):
                prompt += f"{letter}. {options[letter]}\n"
            prompt += "\nPlease select the correct answer from the options above."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")

            num_choices = 0
            for letter in ascii_uppercase:
                if letter in data_item and data_item[letter] is not None:
                    val = data_item[letter]
                    if isinstance(val, str) and val.strip():
                        num_choices += 1
                    elif not isinstance(val, float):
                        num_choices += 1
            num_choices = max(num_choices, 4)

            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                correct = extracted.upper() == str(answer).upper()

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "category": data_item.get("category", data_item.get("l2-category", "")),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            MMBench,
            split="dev",
            lang="en",
            version="v1.1",
            temperature=0.0,
            max_tokens=256,
        )
    )
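The option loop above accepts non-string values but skips floats; the apparent reason (an assumption from the code, consistent with the hint != "nan" check) is that absent options arrive as NaN, which is a float. A toy row shows the effect:

row = {"A": "cat", "B": float("nan"), "C": 3}

options = {}
for letter, val in row.items():
    if isinstance(val, str) and val.strip():
        options[letter] = val
    elif not isinstance(val, float):  # NaN marks a missing option
        options[letter] = str(val)
print(options)  # {'A': 'cat', 'C': '3'}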
172
environments/eval_environments/mmmu_environment.py
Normal file
@@ -0,0 +1,172 @@
"""MMMU (Massive Multi-discipline Multimodal Understanding) evaluation environment."""

import ast
import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class MMMU(EvalBase):
    """MMMU evaluation - multi-discipline multimodal understanding benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "validation")
        subset = getattr(self, "subset", None)

        if subset:
            dataset = load_dataset("MMMU/MMMU", subset, split=split)
        else:
            subjects = [
                "Accounting", "Agriculture", "Architecture_and_Engineering",
                "Art", "Art_Theory", "Basic_Medical_Science", "Biology",
                "Chemistry", "Clinical_Medicine", "Computer_Science",
                "Design", "Diagnostics_and_Laboratory_Medicine", "Economics",
                "Electronics", "Energy_and_Power", "Finance", "Geography",
                "History", "Literature", "Manage", "Marketing", "Materials",
                "Math", "Mechanical_Engineering", "Music", "Pharmacy",
                "Physics", "Psychology", "Public_Health", "Sociology",
            ]
            all_data = []
            for subj in subjects:
                try:
                    ds = load_dataset("MMMU/MMMU", subj, split=split)
                    for item in ds:
                        item["subject"] = subj
                        all_data.append(item)
                except Exception as e:
                    print(f"Warning: Could not load subject {subj}: {e}")
            print(f"Loaded {len(all_data)} examples from MMMU ({split})")
            return all_data

        print(f"Loaded {len(dataset)} examples from MMMU ({split}, {subset})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_images(self, item: dict) -> List[str]:
        """Extract all images from the item (MMMU can have multiple images)."""
        images = []
        for i in range(1, 8):  # MMMU supports up to 7 images
            key = f"image_{i}"
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    images.append(self.encode_image(item[key]))
        return images

    def build_messages(self, item: dict) -> List[dict]:
        images = self.get_images(item)
        question = item.get("question", "")
        options = item.get("options", [])

        if isinstance(options, str):
            try:
                options = ast.literal_eval(options)
            except Exception:
                options = []

        if options:
            options_text = "\n".join([
                f"({ascii_uppercase[i]}) {opt}" for i, opt in enumerate(options)
            ])
            prompt = f"Question: {question}\n\nOptions:\n{options_text}\n\nPlease select the correct answer from the options above."
        else:
            prompt = f"Question: {question}\n\nProvide your answer."

        content = []
        for img_b64 in images:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_b64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        """Extract the answer letter from the response."""
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")
            options = data_item.get("options", [])
            if isinstance(options, str):
                try:
                    options = ast.literal_eval(options)
                except Exception:
                    options = []

            num_choices = len(options) if options else 4
            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                correct = extracted.upper() == answer.upper()

            sample = {
                "id": data_item.get("id", ""),
                "question": data_item.get("question", "")[:200],
                "subject": data_item.get("subject", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            MMMU,
            split="validation",
            temperature=0.0,
            max_tokens=1024,
        )
    )
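MMMU serializes each item's options as a stringified Python list, so they are parsed back before prompting; ast.literal_eval handles this without executing arbitrary code:

import ast

options = ast.literal_eval("['parallelogram', 'rhombus', 'square']")
print(options[1])  # rhombus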
208
environments/eval_environments/mmmu_pro_environment.py
Normal file
@@ -0,0 +1,208 @@
"""MMMU-Pro evaluation environment."""

import ast
import asyncio
import base64
import io
from collections import Counter
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class MMMUPro(EvalBase):
    """MMMU-Pro evaluation - harder version of MMMU with 10 choices."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "test")
        variant = getattr(self, "variant", "standard")  # standard, vision, standard_4

        config_map = {
            "standard": "standard (10 options)",
            "standard_4": "standard (4 options)",
            "vision": "vision",
        }
        config = config_map.get(variant, "standard (10 options)")

        try:
            dataset = load_dataset("MMMU/MMMU_Pro", config, split=split)
            print(f"Loaded {len(dataset)} examples from MMMU-Pro ({split}, {config})")
            return list(dataset)
        except Exception as e:
            print(f"Error loading MMMU-Pro: {e}")
            try:
                dataset = load_dataset("MMMU/MMMU_Pro", "standard (10 options)", split="test")
                print(f"Loaded {len(dataset)} examples from MMMU-Pro (test)")
                return list(dataset)
            except Exception:
                raise ValueError(f"Could not load MMMU-Pro dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_images(self, item: dict) -> List[str]:
        """Extract all images from the item."""
        images = []
        for i in range(1, 8):
            key = f"image_{i}"
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    images.append(self.encode_image(item[key]))
        if "image" in item and item["image"] is not None:
            if isinstance(item["image"], Image.Image):
                images.append(self.encode_image(item["image"]))
        return images

    def build_messages(self, item: dict) -> List[dict]:
        images = self.get_images(item)
        question = item.get("question", "")
        options = item.get("options", [])

        if isinstance(options, str):
            try:
                options = ast.literal_eval(options)
            except Exception:
                options = []

        variant = getattr(self, "variant", "standard")

        if variant == "vision":
            prompt = "Answer the following multiple-choice question in the image. Answer directly with the option letter from the given choices."
        else:
            if options:
                options_text = "\n".join([
                    f"{ascii_uppercase[i]}. {opt}" for i, opt in enumerate(options)
                ])
                prompt = f"Question: {question}\n\nOptions:\n{options_text}\n\n"

                if variant == "cot":
                    prompt += (
                        "Answer the following multiple-choice question. "
                        "The last line of your response should be of the following format: "
                        "'Answer: $LETTER' (without quotes) where LETTER is one of the options. "
                        "Think step by step before answering."
                    )
                else:
                    prompt += "Answer directly with the option letter from the given choices."
            else:
                prompt = f"Question: {question}\n\nProvide your answer."

        content = []
        for img_b64 in images:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_b64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer_cot(self, response: str) -> Optional[str]:
        """Extract the answer from the COT response format 'Answer: X'."""
        lines = [x.strip() for x in response.strip().split("\n")]

        for line in reversed(lines):
            if line.startswith("Answer:"):
                rest = line[len("Answer:"):].strip()
                letter_counts = Counter(ch for ch in rest.upper() if ch in ascii_uppercase[:10])
                if len(letter_counts) == 1:
                    return list(letter_counts.keys())[0]
                elif letter_counts:
                    for ch in rest.upper():
                        if ch in ascii_uppercase[:10]:
                            return ch
        return None

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        """Extract the answer letter from the response."""
        variant = getattr(self, "variant", "standard")

        if variant == "cot":
            cot_answer = self.extract_answer_cot(response)
            if cot_answer:
                return cot_answer, "cot_extraction"

        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")
            options = data_item.get("options", [])
            if isinstance(options, str):
                try:
                    options = ast.literal_eval(options)
                except Exception:
                    options = []

            num_choices = len(options) if options else 10
            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                correct = extracted.upper() == answer.upper()

            sample = {
                "id": data_item.get("id", ""),
                "question": data_item.get("question", "")[:200],
                "subject": data_item.get("subject", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            MMMUPro,
            split="test",
            variant="standard",
            temperature=0.0,
            max_tokens=1024,
        )
    )
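The cot variant's extractor scans lines bottom-up for an "Answer: ..." line and trusts it outright only when exactly one valid letter appears after the colon; the core ambiguity check is:

from collections import Counter
from string import ascii_uppercase

rest = "B."  # text following "Answer:" on the last matching line
letter_counts = Counter(ch for ch in rest.upper() if ch in ascii_uppercase[:10])
print(list(letter_counts.keys())[0] if len(letter_counts) == 1 else None)  # B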
150
environments/eval_environments/mmstar_environment.py
Normal file
@@ -0,0 +1,150 @@
"""MMStar evaluation environment."""

import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class MMStar(EvalBase):
    """MMStar evaluation - expert-level multimodal benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "val")

        try:
            dataset = load_dataset("Lin-Chen/MMStar", split=split)
            print(f"Loaded {len(dataset)} examples from MMStar ({split})")
            return list(dataset)
        except Exception as e:
            print(f"Warning: Could not load MMStar: {e}")
            try:
                dataset = load_dataset("lmms-lab/MMStar", split=split)
                print(f"Loaded {len(dataset)} examples from MMStar ({split})")
                return list(dataset)
            except Exception:
                raise ValueError(f"Could not load MMStar dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                if isinstance(item[key], Image.Image):
                    return self.encode_image(item[key])
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        options = {}
        for letter in ascii_uppercase[:6]:  # MMStar typically has up to 6 options
            if letter in item and item[letter] is not None:
                val = item[letter]
                if isinstance(val, str) and val.strip():
                    options[letter] = val

        prompt = f"Question: {question}\n"
        if options:
            prompt += "Options:\n"
            for letter in sorted(options.keys()):
                prompt += f"{letter}. {options[letter]}\n"
            prompt += "\nPlease select the correct answer from the options above."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            answer = data_item.get("answer", "")

            num_choices = sum(
                1 for letter in ascii_uppercase[:6]
                if letter in data_item and data_item[letter] is not None
                and isinstance(data_item[letter], str) and data_item[letter].strip()
            )
            num_choices = max(num_choices, 4)

            extracted, method = self.extract_answer(response, num_choices)

            correct = False
            if extracted and answer:
                correct = extracted.upper() == str(answer).upper()

            sample = {
                "id": data_item.get("index", data_item.get("id", "")),
                "question": data_item.get("question", "")[:200],
                "category": data_item.get("category", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],
                "correct": correct,
                "extraction_method": method,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            MMStar,
            split="val",
            temperature=0.0,
            max_tokens=256,
        )
    )
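When an MMStar item carries no per-letter option fields (some exports apparently embed the choices in the question text itself; an assumption, not stated in this diff), the choice count falls back to 4, which keeps the letter-validation set non-empty:

from string import ascii_uppercase

item = {"question": "Which shape is larger? Options: A. circle B. square", "answer": "B"}
num_choices = sum(
    1 for letter in ascii_uppercase[:6]
    if isinstance(item.get(letter), str) and item[letter].strip()
)
print(max(num_choices, 4))  # 4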
174
environments/eval_environments/mmt_bench_environment.py
Normal file
@@ -0,0 +1,174 @@
"""MMT-Bench evaluation environment."""

import asyncio
import base64
import io
from string import ascii_uppercase
from typing import List, Optional, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner
from environments.eval_environments.eval_helpers import (
    extract_letter_from_answer_tag,
    extract_mcqa_answer_with_fallback,
)


class MMTBench(EvalBase):
    """MMT-Bench evaluation - multi-task multimodal benchmark."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "train")
        max_samples = getattr(self, "max_samples", None)  # None = use all samples

        try:
            # Try a full dataset download first
            dataset = load_dataset("OpenGVLab/MMT-Bench", split=split)
            data = list(dataset)
            if max_samples:
                data = data[:max_samples]
            print(f"Loaded {len(data)} examples from MMT-Bench ({split})")
            return data
        except Exception as e:
            print(f"Warning: Full download failed, using streaming: {e}")
            # Fall back to streaming if the full download fails (known column mismatch issue)
            try:
                dataset = load_dataset("OpenGVLab/MMT-Bench", split=split, streaming=True)
                if max_samples:
                    data = list(dataset.take(max_samples))
                else:
                    # Stream all available samples
                    data = []
                    for i, item in enumerate(dataset):
                        data.append(item)
                        if i % 5000 == 0 and i > 0:
                            print(f"  Streamed {i} samples...")
                print(f"Loaded {len(data)} examples from MMT-Bench ({split}, streaming)")
                return data
            except Exception:
                raise ValueError(f"Could not load MMT-Bench dataset: {e}")

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> Optional[str]:
        for key in ["image", "decoded_image"]:
            if key in item and item[key] is not None:
                val = item[key]
                if isinstance(val, Image.Image):
                    return self.encode_image(val)
                elif isinstance(val, str) and len(val) > 100:
                    # Already a base64-encoded string
                    return val
        return None

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")
        hint = item.get("hint", "")

        options = {}
        for letter in ascii_uppercase[:8]:  # Support up to 8 options
            if letter in item and item[letter] is not None:
                val = item[letter]
                if isinstance(val, str) and val.strip():
                    options[letter] = val

        prompt = ""
        if hint and str(hint).strip() and str(hint).lower() != "nan":
            prompt += f"Hint: {hint}\n"
        prompt += f"Question: {question}\n"

        if options:
            prompt += "Options:\n"
            for letter in sorted(options.keys()):
                prompt += f"{letter}. {options[letter]}\n"
            prompt += "\nPlease select the correct answer from the options above."

        content = []
        if image_base64:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            })
        content.append({"type": "text", "text": prompt})

        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
        valid_letters = set(ascii_uppercase[:num_choices])

        letter, method = extract_letter_from_answer_tag(response, valid_letters)
        if letter:
            return letter, method

        letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
        return letter, method

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
|
||||
num_choices = sum(
|
||||
1 for letter in ascii_uppercase[:8]
|
||||
if letter in data_item and data_item[letter] is not None
|
||||
and isinstance(data_item[letter], str) and data_item[letter].strip()
|
||||
)
|
||||
num_choices = max(num_choices, 4)
|
||||
|
||||
extracted, method = self.extract_answer(response, num_choices)
|
||||
|
||||
correct = False
|
||||
if extracted and answer:
|
||||
correct = extracted.upper() == str(answer).upper()
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"task": data_item.get("task", ""),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"extraction_method": method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MMTBench,
|
||||
split="val",
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
)
|
||||
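Given the streaming fallback above, a quick smoke test of this environment is easiest with max_samples, which setup_data honors on both code paths. A hypothetical invocation, assuming eval_runner forwards keyword arguments onto the environment the same way the __main__ blocks here rely on:

import asyncio

from environments.eval_environments.eval_base import eval_runner
from environments.eval_environments.mmt_bench_environment import MMTBench

# max_samples is read via getattr(self, "max_samples", None) in setup_data,
# so passing it here caps both the full-download and the streaming path.
asyncio.run(
    eval_runner(
        MMTBench,
        split="val",
        max_samples=100,
        temperature=0.0,
        max_tokens=256,
    )
)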
187
environments/eval_environments/mmvet_environment.py
Normal file
"""MMVet evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class MMVet(EvalBase):
|
||||
"""MMVet evaluation - comprehensive VLM capability benchmark with GPT-based scoring."""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "test")
|
||||
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/MMVet", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MMVet ({split})")
|
||||
return list(dataset)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load MMVet: {e}")
|
||||
try:
|
||||
dataset = load_dataset("whyu/MM-Vet", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MMVet ({split})")
|
||||
return list(dataset)
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load MMVet dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> Optional[str]:
|
||||
for key in ["image", "decoded_image"]:
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
return self.encode_image(item[key])
|
||||
return None
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": question})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
async def gpt_score(self, question: str, answer: str, prediction: str) -> float:
|
||||
"""Use GPT to score the prediction against the ground truth answer."""
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
pred_lower = prediction.lower().strip()
|
||||
ans_lower = answer.lower().strip()
|
||||
if pred_lower == ans_lower:
|
||||
return 1.0
|
||||
elif ans_lower in pred_lower or pred_lower in ans_lower:
|
||||
return 0.5
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
prompt = f"""You are evaluating the quality of a model's answer compared to a reference answer.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Reference Answer: {answer}
|
||||
|
||||
Model's Answer: {prediction}
|
||||
|
||||
Score the model's answer on a scale from 0 to 1:
|
||||
- 1.0: Completely correct and matches the reference
|
||||
- 0.5-0.9: Partially correct or captures the main idea
|
||||
- 0.1-0.4: Somewhat related but mostly incorrect
|
||||
- 0.0: Completely wrong or irrelevant
|
||||
|
||||
Output ONLY a single number between 0 and 1."""
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
try:
|
||||
score = float(result)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT scoring error: {e}")
|
||||
pred_lower = prediction.lower().strip()
|
||||
ans_lower = answer.lower().strip()
|
||||
if pred_lower == ans_lower:
|
||||
return 1.0
|
||||
elif ans_lower in pred_lower:
|
||||
return 0.5
|
||||
return 0.0
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
question = data_item.get("question", "")
|
||||
answer = data_item.get("answer", "")
|
||||
|
||||
use_gpt_scoring = getattr(self, "use_gpt_scoring", True)
|
||||
if use_gpt_scoring:
|
||||
score = await self.gpt_score(question, answer, response)
|
||||
else:
|
||||
pred_lower = response.lower().strip()
|
||||
ans_lower = answer.lower().strip()
|
||||
if pred_lower == ans_lower:
|
||||
score = 1.0
|
||||
elif ans_lower in pred_lower:
|
||||
score = 0.5
|
||||
else:
|
||||
score = 0.0
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("id", "")),
|
||||
"question": question[:200],
|
||||
"category": data_item.get("capability", data_item.get("category", "")),
|
||||
"answer": answer[:200],
|
||||
"prediction": response[:500],
|
||||
"score": score,
|
||||
}
|
||||
|
||||
return {"accuracy": score}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MMVet,
|
||||
split="test",
|
||||
use_gpt_scoring=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
)
|
||||
)
|
||||
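One fragile spot in gpt_score above: the judge reply is parsed with a bare float(result), so any decoration around the number (for example "Score: 0.7") falls through to 0.0. A more tolerant parse, offered only as a sketch rather than what ships above:

import re


def parse_judge_score(text: str) -> float:
    """Pull the first number out of a judge reply and clamp it to [0, 1]."""
    m = re.search(r"\d+(?:\.\d+)?", text or "")
    if not m:
        return 0.0
    return max(0.0, min(1.0, float(m.group(0))))


assert parse_judge_score("0.7") == 0.7
assert parse_judge_score("Score: 0.85") == 0.85
assert parse_judge_score("") == 0.0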
158
environments/eval_environments/mmvp_environment.py
Normal file
"""MMVP (Multimodal Visual Perception) evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from string import ascii_uppercase
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
from environments.eval_environments.eval_helpers import (
|
||||
extract_letter_from_answer_tag,
|
||||
extract_mcqa_answer_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
class MMVP(EvalBase):
|
||||
"""MMVP evaluation - visual perception benchmark testing CLIP blindspots."""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "train") # MMVP only has train split
|
||||
|
||||
try:
|
||||
dataset = load_dataset("MMVP/MMVP", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MMVP ({split})")
|
||||
return list(dataset)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load MMVP: {e}")
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/MMVP", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MMVP ({split})")
|
||||
return list(dataset)
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load MMVP dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_images(self, item: dict) -> List[str]:
|
||||
"""Get all images from item (MMVP typically has paired images)."""
|
||||
images = []
|
||||
for i in range(1, 3):
|
||||
key = f"image_{i}"
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
images.append(self.encode_image(item[key]))
|
||||
|
||||
if not images and "image" in item and item["image"] is not None:
|
||||
if isinstance(item["image"], Image.Image):
|
||||
images.append(self.encode_image(item["image"]))
|
||||
|
||||
return images
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
images = self.get_images(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
options = {}
|
||||
for letter in ascii_uppercase[:4]: # MMVP typically has 2-4 options
|
||||
if letter in item and item[letter] is not None:
|
||||
val = item[letter]
|
||||
if isinstance(val, str) and val.strip():
|
||||
options[letter] = val
|
||||
|
||||
prompt = f"Question: {question}\n"
|
||||
if options:
|
||||
prompt += "Options:\n"
|
||||
for letter in sorted(options.keys()):
|
||||
prompt += f"{letter}. {options[letter]}\n"
|
||||
prompt += "\nPlease select the correct answer from the options above."
|
||||
|
||||
content = []
|
||||
for img_b64 in images:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{img_b64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
|
||||
valid_letters = set(ascii_uppercase[:num_choices])
|
||||
|
||||
letter, method = extract_letter_from_answer_tag(response, valid_letters)
|
||||
if letter:
|
||||
return letter, method
|
||||
|
||||
letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
|
||||
return letter, method
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
|
||||
num_choices = sum(
|
||||
1 for letter in ascii_uppercase[:4]
|
||||
if letter in data_item and data_item[letter] is not None
|
||||
and isinstance(data_item[letter], str) and data_item[letter].strip()
|
||||
)
|
||||
num_choices = max(num_choices, 2)
|
||||
|
||||
extracted, method = self.extract_answer(response, num_choices)
|
||||
|
||||
correct = False
|
||||
if extracted and answer:
|
||||
correct = extracted.upper() == str(answer).upper()
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"category": data_item.get("category", ""),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"extraction_method": method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MMVP,
|
||||
split="test",
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
)
|
||||
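Scoring note: the original MMVP paper scores image pairs jointly, counting a pair only when both of its questions are answered correctly, while this environment reports plain per-question accuracy. Pair-level accuracy could be recovered post-hoc from the emitted samples; the sketch below assumes consecutive numeric ids form a pair, which matches the original 150-pair, 300-question release but should be verified against the loaded split.

from collections import defaultdict


def pairwise_accuracy(samples):
    """Pair-level MMVP accuracy from the per-question samples emitted above.

    Assumes ids 1..300 where questions (1, 2), (3, 4), ... share an image pair.
    """
    pairs = defaultdict(list)
    for s in samples:
        pairs[(int(s["id"]) - 1) // 2].append(bool(s["correct"]))
    scored = [len(v) == 2 and all(v) for v in pairs.values()]
    return sum(scored) / len(scored) if scored else 0.0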
152
environments/eval_environments/ocrbench_environment.py
Normal file
"""OCRBench evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class OCRBench(EvalBase):
|
||||
"""OCRBench evaluation - OCR capabilities benchmark."""
|
||||
|
||||
# Categories and their scoring
|
||||
CATEGORIES = [
|
||||
'Regular Text Recognition',
|
||||
'Irregular Text Recognition',
|
||||
'Artistic Text Recognition',
|
||||
'Handwriting Recognition',
|
||||
'Digit String Recognition',
|
||||
'Non-Semantic Text Recognition',
|
||||
'Scene Text-centric VQA',
|
||||
'Doc-oriented VQA',
|
||||
'Key Information Extraction',
|
||||
'Handwritten Mathematical Expression Recognition',
|
||||
]
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "test")
|
||||
|
||||
try:
|
||||
dataset = load_dataset("echo840/OCRBench", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from OCRBench ({split})")
|
||||
return list(dataset)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load OCRBench: {e}")
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/OCRBench", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from OCRBench ({split})")
|
||||
return list(dataset)
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load OCRBench dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> Optional[str]:
|
||||
for key in ["image", "decoded_image"]:
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
return self.encode_image(item[key])
|
||||
return None
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
prompt = f"{question}\n\nAnswer the question using a single word or phrase."
|
||||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def score_ocr(self, prediction: str, answers: List[str], category: str) -> bool:
|
||||
"""Category-specific scoring for OCR tasks."""
|
||||
predict = prediction.strip()
|
||||
|
||||
if category == 'Handwritten Mathematical Expression Recognition':
|
||||
predict_clean = predict.replace('\n', ' ').replace(' ', '')
|
||||
for answer in answers:
|
||||
answer_clean = answer.strip().replace('\n', ' ').replace(' ', '')
|
||||
if answer_clean in predict_clean:
|
||||
return True
|
||||
else:
|
||||
predict_lower = predict.lower().replace('\n', ' ')
|
||||
for answer in answers:
|
||||
answer_lower = answer.lower().strip().replace('\n', ' ')
|
||||
if answer_lower in predict_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answers = data_item.get("answer", [])
|
||||
if isinstance(answers, str):
|
||||
try:
|
||||
answers = eval(answers)
|
||||
except Exception:
|
||||
answers = [answers]
|
||||
if not isinstance(answers, list):
|
||||
answers = [answers]
|
||||
|
||||
category = data_item.get("category", "")
|
||||
|
||||
correct = self.score_ocr(response, answers, category)
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"category": category,
|
||||
"answer": answers[0] if answers else "",
|
||||
"prediction": response[:200],
|
||||
"correct": correct,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
OCRBench,
|
||||
split="test",
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
)
|
||||
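Concretely, score_ocr is containment-based: a prediction counts as correct if any ground-truth string appears inside it, case-insensitively for the text categories and whitespace-insensitively for handwritten math. Since the method never touches self, it can be exercised unbound:

# Worked cases against the score_ocr logic above.
assert OCRBench.score_ocr(None, "The text reads: HELLO", ["hello"], "Regular Text Recognition")
assert OCRBench.score_ocr(None, "x ^ 2 + 1", ["x^2+1"], "Handwritten Mathematical Expression Recognition")
assert not OCRBench.score_ocr(None, "HELP", ["hello"], "Regular Text Recognition")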
144
environments/eval_environments/pope_environment.py
Normal file
"""POPE (Polling-based Object Probing Evaluation) evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class POPE(EvalBase):
|
||||
"""POPE evaluation - object hallucination benchmark with yes/no questions."""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "test")
|
||||
variant = getattr(self, "variant", "random") # random, popular, adversarial
|
||||
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/POPE", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from POPE ({split})")
|
||||
return list(dataset)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load POPE: {e}")
|
||||
try:
|
||||
dataset = load_dataset("OpenGVLab/POPE", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from POPE ({split})")
|
||||
return list(dataset)
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load POPE dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> Optional[str]:
|
||||
for key in ["image", "decoded_image"]:
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
return self.encode_image(item[key])
|
||||
return None
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
prompt = f"{question}\n\nPlease answer yes or no."
|
||||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def extract_yorn(self, response: str) -> str:
|
||||
"""Extract Yes/No from response."""
|
||||
response_lower = response.lower().strip()
|
||||
|
||||
if response_lower.startswith("yes"):
|
||||
return "Yes"
|
||||
if response_lower.startswith("no"):
|
||||
return "No"
|
||||
|
||||
yes_patterns = [r'\byes\b', r'\btrue\b', r'\bcorrect\b', r'\baffirmative\b']
|
||||
no_patterns = [r'\bno\b', r'\bfalse\b', r'\bincorrect\b', r'\bnegative\b']
|
||||
|
||||
for pattern in yes_patterns:
|
||||
if re.search(pattern, response_lower):
|
||||
return "Yes"
|
||||
|
||||
for pattern in no_patterns:
|
||||
if re.search(pattern, response_lower):
|
||||
return "No"
|
||||
|
||||
return "Unknown"
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
extracted = self.extract_yorn(response)
|
||||
|
||||
answer_norm = answer.strip().lower()
|
||||
if answer_norm in ["yes", "true", "1"]:
|
||||
answer_norm = "Yes"
|
||||
elif answer_norm in ["no", "false", "0"]:
|
||||
answer_norm = "No"
|
||||
else:
|
||||
answer_norm = answer.strip()
|
||||
|
||||
correct = extracted == answer_norm
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("question_id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"category": data_item.get("category", ""),
|
||||
"answer": answer_norm,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:200],
|
||||
"correct": correct,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
POPE,
|
||||
split="test",
|
||||
temperature=0.0,
|
||||
max_tokens=64,
|
||||
)
|
||||
)
|
||||
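POPE is conventionally reported with yes-class precision, recall, and F1 alongside accuracy, since the yes-ratio exposes a model's bias toward agreeing. Those can be derived post-hoc from the samples this environment emits; a sketch over the 'answer' and 'prediction' fields:

def pope_f1(samples):
    """Yes-class precision/recall/F1 from the per-item samples above."""
    tp = sum(1 for s in samples if s["prediction"] == "Yes" and s["answer"] == "Yes")
    fp = sum(1 for s in samples if s["prediction"] == "Yes" and s["answer"] == "No")
    fn = sum(1 for s in samples if s["prediction"] == "No" and s["answer"] == "Yes")
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}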
193
environments/eval_environments/seedbench2_plus_environment.py
Normal file
"""SEED-Bench2-Plus evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from string import ascii_uppercase
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
from environments.eval_environments.eval_helpers import (
|
||||
extract_letter_from_answer_tag,
|
||||
extract_mcqa_answer_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
class SEEDBench2Plus(EvalBase):
|
||||
"""SEED-Bench2-Plus evaluation - comprehensive visual understanding benchmark."""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "test")
|
||||
max_samples = getattr(self, "max_samples", None)
|
||||
|
||||
try:
|
||||
# Use streaming to avoid memory issues with this large dataset
|
||||
dataset = load_dataset("lmms-lab/SEED-Bench-2", split=split, streaming=True)
|
||||
|
||||
# Take samples from streaming dataset
|
||||
if max_samples:
|
||||
data = list(dataset.take(max_samples))
|
||||
else:
|
||||
# Default to 1000 samples to avoid loading entire 24k dataset
|
||||
data = list(dataset.take(1000))
|
||||
|
||||
print(f"Loaded {len(data)} examples from SEED-Bench2 ({split}, streaming)")
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load SEED-Bench2: {e}")
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/SEED-Bench", split=split, streaming=True)
|
||||
if max_samples:
|
||||
data = list(dataset.take(max_samples))
|
||||
else:
|
||||
data = list(dataset.take(1000))
|
||||
print(f"Loaded {len(data)} examples from SEED-Bench ({split}, streaming)")
|
||||
return data
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load SEED-Bench2-Plus dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> Optional[str]:
|
||||
for key in ["image", "decoded_image"]:
|
||||
if key in item and item[key] is not None:
|
||||
val = item[key]
|
||||
if isinstance(val, Image.Image):
|
||||
return self.encode_image(val)
|
||||
elif isinstance(val, list) and len(val) > 0:
|
||||
# SEED-Bench-2 stores images as a list of PIL images
|
||||
if isinstance(val[0], Image.Image):
|
||||
return self.encode_image(val[0])
|
||||
return None
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
options = {}
|
||||
for letter in ascii_uppercase[:6]:
|
||||
# Check for choice_a, choice_b format
|
||||
choice_key = f"choice_{letter.lower()}"
|
||||
if choice_key in item and item[choice_key] is not None:
|
||||
val = item[choice_key]
|
||||
if isinstance(val, str) and val.strip():
|
||||
options[letter] = val
|
||||
elif letter in item and item[letter] is not None:
|
||||
val = item[letter]
|
||||
if isinstance(val, str) and val.strip():
|
||||
options[letter] = val
|
||||
|
||||
if not options:
|
||||
choices = item.get("choices", [])
|
||||
if isinstance(choices, str):
|
||||
try:
|
||||
choices = eval(choices)
|
||||
except Exception:
|
||||
choices = []
|
||||
for i, choice in enumerate(choices):
|
||||
options[ascii_uppercase[i]] = choice
|
||||
|
||||
prompt = f"Question: {question}\n"
|
||||
if options:
|
||||
prompt += "Options:\n"
|
||||
for letter in sorted(options.keys()):
|
||||
prompt += f"{letter}. {options[letter]}\n"
|
||||
prompt += "\nPlease select the correct answer from the options above."
|
||||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
|
||||
valid_letters = set(ascii_uppercase[:num_choices])
|
||||
|
||||
letter, method = extract_letter_from_answer_tag(response, valid_letters)
|
||||
if letter:
|
||||
return letter, method
|
||||
|
||||
letter, method = extract_mcqa_answer_with_fallback(response, num_choices)
|
||||
return letter, method
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
|
||||
choices = data_item.get("choices", [])
|
||||
if isinstance(choices, str):
|
||||
try:
|
||||
choices = eval(choices)
|
||||
except Exception:
|
||||
choices = []
|
||||
|
||||
num_choices = len(choices) if choices else 4
|
||||
if num_choices == 0:
|
||||
num_choices = sum(
|
||||
1 for letter in ascii_uppercase[:6]
|
||||
if letter in data_item and data_item[letter] is not None
|
||||
)
|
||||
num_choices = max(num_choices, 4)
|
||||
|
||||
extracted, method = self.extract_answer(response, num_choices)
|
||||
|
||||
correct = False
|
||||
if extracted and answer:
|
||||
correct = extracted.upper() == str(answer).upper()
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("question_id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"category": data_item.get("question_type_id", data_item.get("category", "")),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"extraction_method": method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
SEEDBench2Plus,
|
||||
split="test",
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
)
|
||||
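Caveat on the default above: with max_samples unset, setup_data takes only the first 1000 streamed items rather than the whole split, so headline numbers are not directly comparable to full SEED-Bench-2 results unless max_samples is raised explicitly. A hypothetical full-split run; datasets' take() simply stops when a shorter stream ends:

import asyncio

from environments.eval_environments.eval_base import eval_runner
from environments.eval_environments.seedbench2_plus_environment import SEEDBench2Plus

asyncio.run(
    eval_runner(
        SEEDBench2Plus,
        split="test",
        max_samples=30000,  # larger than the split; take() stops at stream end
        temperature=0.0,
        max_tokens=256,
    )
)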
165
environments/eval_environments/vlmblind_environment.py
Normal file
"""VLMBlind (VLMs are Blind) evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class VLMBlind(EvalBase):
|
||||
"""VLMBlind evaluation - tests basic visual perception abilities of VLMs."""
|
||||
|
||||
TASK_PATTERNS = {
|
||||
"Subway Connections": r"\{([^}]+)\}",
|
||||
"Nested Squares": r"\{([^}]+)\}",
|
||||
"Line Plot Intersections": r"\{([^}]+)\}",
|
||||
"Touching Circles": None, # Substring match
|
||||
"Counting Grid": r"(\d+)\s*(?:rows?|r).*?(\d+)\s*(?:columns?|cols?|c)|(\d+)\s*[xX×]\s*(\d+)",
|
||||
"Olympic Counting": None, # Substring match
|
||||
"Circled Letter": r"\{([^}]+)\}",
|
||||
}
|
||||
|
||||
def setup_data(self) -> list:
|
||||
# XAI/vlmsareblind only has 'valid' split
|
||||
split = getattr(self, "split", "valid")
|
||||
|
||||
try:
|
||||
dataset = load_dataset("XAI/vlmsareblind", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from VLMBlind ({split})")
|
||||
return list(dataset)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load VLMBlind: {e}")
|
||||
try:
|
||||
# Try valid split explicitly
|
||||
dataset = load_dataset("XAI/vlmsareblind", split="valid")
|
||||
print(f"Loaded {len(dataset)} examples from VLMBlind (valid)")
|
||||
return list(dataset)
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load VLMBlind dataset: {e}")
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> Optional[str]:
|
||||
for key in ["image", "decoded_image"]:
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
return self.encode_image(item[key])
|
||||
return None
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
# XAI/vlmsareblind uses 'prompt' instead of 'question'
|
||||
question = item.get("prompt", item.get("question", ""))
|
||||
|
||||
prompt = f"{question}\n\nProvide your answer."
|
||||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def extract_and_score(self, response: str, answer: str, task: str) -> Tuple[bool, str]:
|
||||
"""Task-specific answer extraction and scoring."""
|
||||
response_lower = response.lower().strip()
|
||||
answer_lower = str(answer).lower().strip()
|
||||
|
||||
if task in ["Subway Connections", "Nested Squares", "Line Plot Intersections", "Circled Letter"]:
|
||||
match = re.search(r"\{([^}]+)\}", response)
|
||||
if match:
|
||||
extracted = match.group(1).strip().lower()
|
||||
return extracted == answer_lower, extracted
|
||||
return answer_lower in response_lower, response_lower[:50]
|
||||
|
||||
elif task == "Touching Circles":
|
||||
return answer_lower in response_lower, response_lower[:50]
|
||||
|
||||
elif "Counting Grid" in task or "Grid" in task:
|
||||
patterns = [
|
||||
r"(\d+)\s*[xX×]\s*(\d+)",
|
||||
r"(\d+)\s*(?:rows?|r).*?(\d+)\s*(?:columns?|cols?|c)",
|
||||
r"(\d+)\s*(?:columns?|cols?|c).*?(\d+)\s*(?:rows?|r)",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, response)
|
||||
if match:
|
||||
groups = match.groups()
|
||||
extracted = f"{groups[0]}x{groups[1]}"
|
||||
ans_match = re.search(r"(\d+)\s*[xX×,]\s*(\d+)", answer)
|
||||
if ans_match:
|
||||
answer_parsed = f"{ans_match.group(1)}x{ans_match.group(2)}"
|
||||
return extracted == answer_parsed, extracted
|
||||
return answer_lower in response_lower, response_lower[:50]
|
||||
|
||||
elif "Olympic" in task or "Counting" in task:
|
||||
return answer_lower in response_lower, response_lower[:50]
|
||||
|
||||
else:
|
||||
return answer_lower in response_lower, response_lower[:50]
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
# XAI/vlmsareblind uses 'groundtruth' instead of 'answer'
|
||||
answer = data_item.get("groundtruth", data_item.get("answer", ""))
|
||||
task = data_item.get("task", data_item.get("category", ""))
|
||||
|
||||
correct, extracted = self.extract_and_score(response, answer, task)
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("index", data_item.get("id", "")),
|
||||
"question": data_item.get("prompt", data_item.get("question", ""))[:200],
|
||||
"task": task,
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
VLMBlind,
|
||||
split="test",
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
)
|
||||
)
|
||||
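The counting-grid branch above normalizes both sides to an "RxC" string before comparing, so phrasing differences in the reply do not matter as long as the two numbers are recoverable. Two worked cases, calling the method unbound since it ignores self:

# "3 rows and 4 columns" is caught by the rows/columns pattern.
ok, got = VLMBlind.extract_and_score(None, "The grid has 3 rows and 4 columns.", "3x4", "Counting Grid")
assert ok and got == "3x4"

# A "5x5" reply matches the NxM pattern, and the "5,5" ground truth is reparsed to "5x5".
ok, got = VLMBlind.extract_and_score(None, "It looks like a 5x5 grid.", "5,5", "Counting Grid")
assert ok and got == "5x5"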