Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-27 17:23:08 +00:00
lots of vision benchmarks
commit 75de490849 (parent a7a87a33e4)
12 changed files with 3300 additions and 0 deletions
environments/eval_environments/chartqa_environment.py (new file, 198 lines)
@@ -0,0 +1,198 @@
"""ChartQA evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class ChartQA(EvalBase):
|
||||
"""
|
||||
ChartQA evaluation environment.
|
||||
|
||||
A benchmark for question answering about charts with relaxed accuracy scoring.
|
||||
"""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
subset = getattr(self, "subset", "human")
|
||||
dataset = load_dataset("ahmed-masry/ChartQA", split="test")
|
||||
|
||||
if subset == "human":
|
||||
dataset = dataset.filter(lambda x: x.get("type", "") == "human")
|
||||
elif subset == "augmented":
|
||||
dataset = dataset.filter(lambda x: x.get("type", "") == "augmented")
|
||||
|
||||
print(f"Loaded {len(dataset)} examples from ChartQA ({subset})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
images_path: Optional[str] = getattr(self, "images_path", None)
|
||||
if images_path:
|
||||
imgname = item.get("imgname", "")
|
||||
image_path = Path(images_path) / imgname
|
||||
with open(image_path, "rb") as f:
|
||||
return base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
if "image" in item and item["image"] is not None:
|
||||
img = item["image"]
|
||||
if isinstance(img, bytes):
|
||||
return base64.b64encode(img).decode("utf-8")
|
||||
elif isinstance(img, Image.Image):
|
||||
return self.encode_image(img)
|
||||
else:
|
||||
raise ValueError(f"Unknown image type: {type(img)}")
|
||||
|
||||
raise ValueError("Could not find image for item")
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
query = item.get("query", item.get("question", ""))
|
||||
|
||||
prompt = f"""Answer this question about the chart. Provide only the answer, nothing else.
|
||||
|
||||
Question: {query}"""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def extract_answer(self, response: str) -> str:
|
||||
response = response.strip()
|
||||
|
||||
patterns = [
|
||||
r"(?:the answer is|answer:)\s*(.+?)(?:\.|$)",
|
||||
r"^(\d+[\d,\.]*%?)$",
|
||||
r"^(yes|no)$",
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, response, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
if len(response.split()) <= 5:
|
||||
return response
|
||||
|
||||
first_line = response.split("\n")[0]
|
||||
return first_line.strip()
|
||||
|
||||
def _to_float(self, text: str) -> Optional[float]:
|
||||
"""
|
||||
Convert string to float, handling percentages.
|
||||
|
||||
Following VLMEvalKit: percentages are converted to decimals (5% -> 0.05).
|
||||
"""
|
||||
text = str(text).strip()
|
||||
try:
|
||||
# Remove commas and dollar signs
|
||||
text = text.replace(",", "").replace("$", "")
|
||||
if text.endswith("%"):
|
||||
# Convert percentage to decimal (VLMEvalKit behavior)
|
||||
return float(text.rstrip("%")) / 100.0
|
||||
else:
|
||||
return float(text)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def score_relaxed(self, prediction: str, answer: str) -> bool:
|
||||
"""
|
||||
Calculate relaxed correctness following VLMEvalKit.
|
||||
|
||||
For numeric answers: allows 5% relative tolerance.
|
||||
For non-numeric answers: exact match (case-insensitive).
|
||||
|
||||
Reference: https://arxiv.org/pdf/2203.10244.pdf, section 5.1
|
||||
"""
|
||||
pred = str(prediction).strip()
|
||||
ans = str(answer).strip()
|
||||
|
||||
relaxed_tolerance = getattr(self, "relaxed_tolerance", 0.05)
|
||||
|
||||
pred_float = self._to_float(pred)
|
||||
ans_float = self._to_float(ans)
|
||||
|
||||
if pred_float is not None and ans_float is not None:
|
||||
if ans_float == 0:
|
||||
return abs(pred_float) < 1e-6
|
||||
relative_change = abs(pred_float - ans_float) / abs(ans_float)
|
||||
return relative_change <= relaxed_tolerance
|
||||
|
||||
# Non-numeric: exact match (case-insensitive)
|
||||
return pred.lower() == ans.lower()
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
extracted = self.extract_answer(response)
|
||||
answer = data_item.get("label", data_item.get("answer", ""))
|
||||
correct = self.score_relaxed(extracted, answer)
|
||||
|
||||
sample = {
|
||||
"question": data_item.get("query", data_item.get("question", "")),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"correct": correct,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
ChartQA,
|
||||
subset="human",
|
||||
relaxed_tolerance=0.05,
|
||||
temperature=0.0,
|
||||
max_tokens=2048,
|
||||
)
|
||||
)
|
||||
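The relaxed-accuracy rule above is easiest to see on concrete numbers. A minimal standalone sketch (not part of the committed file) restating the logic of ChartQA.score_relaxed:

# Illustrative sketch only: the relaxed-accuracy comparison used by ChartQA,
# restated as a standalone function.
def relaxed_match(pred: float, ans: float, tol: float = 0.05) -> bool:
    # A zero ground truth falls back to a near-exact comparison.
    if ans == 0:
        return abs(pred) < 1e-6
    return abs(pred - ans) / abs(ans) <= tol

print(relaxed_match(42.0, 40.0))   # True: 5% relative error is within the 5% tolerance
print(relaxed_match(45.0, 40.0))   # False: 12.5% relative error
print(relaxed_match(0.05, 0.05))   # True: "5%" and "0.05" compare equal after conversion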
environments/eval_environments/charxiv_environment.py (new file, 385 lines)
@@ -0,0 +1,385 @@
"""CharXiv evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
DESCRIPTIVE_CATEGORIES = {
|
||||
1: "Information Extraction",
|
||||
2: "Information Extraction",
|
||||
3: "Information Extraction",
|
||||
4: "Information Extraction",
|
||||
5: "Information Extraction",
|
||||
6: "Information Extraction",
|
||||
7: "Information Extraction",
|
||||
8: "Enumeration",
|
||||
9: "Enumeration",
|
||||
10: "Counting",
|
||||
11: "Pattern Recognition",
|
||||
12: "Counting",
|
||||
13: "Enumeration",
|
||||
14: "Enumeration",
|
||||
15: "Enumeration",
|
||||
16: "Pattern Recognition",
|
||||
17: "Compositionality",
|
||||
18: "Pattern Recognition",
|
||||
19: "Counting",
|
||||
}
|
||||
|
||||
REASONING_CATEGORIES = {
|
||||
1: "Text-in-Chart",
|
||||
2: "Text-in-General",
|
||||
3: "Number-in-Chart",
|
||||
4: "Number-in-General",
|
||||
}
|
||||
|
||||
DESCRIPTIVE_QUESTIONS = {
|
||||
1: "What is the title of the chart?",
|
||||
2: "What is the label of the x-axis?",
|
||||
3: "What is the label of the y-axis?",
|
||||
4: "What is the leftmost labeled tick on the x-axis?",
|
||||
5: "What is the rightmost labeled tick on the x-axis?",
|
||||
6: "What is the spatially lowest labeled tick on the y-axis?",
|
||||
7: "What is the spatially highest labeled tick on the y-axis?",
|
||||
8: "What are all the labels in the legend?",
|
||||
9: "List all the categories in the x-axis.",
|
||||
10: "How many distinct bars are there?",
|
||||
11: "Does the chart contain a grid?",
|
||||
12: "How many lines are there in the chart?",
|
||||
13: "Is there a legend in the chart?",
|
||||
14: "What are the names of the curves in the chart?",
|
||||
15: "Does the chart contain horizontal bars?",
|
||||
16: "Do the bars have error bars?",
|
||||
17: "Describe the general trend of the chart.",
|
||||
18: "Is there any point emphasized/highlighted in the chart?",
|
||||
19: "How many sections does the pie chart have?",
|
||||
}
|
||||
|
||||
GRADING_QUERY_TEMPLATE = """You are evaluating a model's answer to a chart understanding question.
|
||||
|
||||
Question: {question}
|
||||
Ground Truth Answer: {answer}
|
||||
Model's Answer: {prediction}
|
||||
|
||||
Please evaluate whether the model's answer is correct or partially correct.
|
||||
Consider semantic equivalence - different phrasings that mean the same thing should be considered correct.
|
||||
For numerical answers, exact matches or very close values should be considered correct.
|
||||
For yes/no questions, the meaning should match the ground truth.
|
||||
For enumeration questions (listing items), all items should be present regardless of order.
|
||||
|
||||
Respond with a JSON object containing:
|
||||
- "extract_answer": The key answer extracted from the model's response
|
||||
- "score": A float from 0.0 to 1.0 indicating correctness (0.0 = wrong, 0.5 = partial, 1.0 = correct)
|
||||
|
||||
Example response: {{"extract_answer": "60", "score": 1.0}}"""
|
||||
|
||||
|
||||
class CharXiv(EvalBase):
|
||||
MODES = ["descriptive", "reasoning"]
|
||||
|
||||
def setup_data(self) -> list:
|
||||
mode = getattr(self, "mode", "descriptive")
|
||||
split = getattr(self, "split", "validation")
|
||||
|
||||
dataset = load_dataset("princeton-nlp/CharXiv", "default", split=split)
|
||||
|
||||
data = []
|
||||
for item in dataset:
|
||||
if mode == "descriptive":
|
||||
for i in range(1, 5):
|
||||
q_key = f"descriptive_q{i}"
|
||||
a_key = f"descriptive_a{i}"
|
||||
if a_key in item and item[a_key]:
|
||||
template_id = item.get(q_key, i)
|
||||
if (
|
||||
isinstance(template_id, int)
|
||||
and template_id in DESCRIPTIVE_QUESTIONS
|
||||
):
|
||||
question = DESCRIPTIVE_QUESTIONS[template_id]
|
||||
else:
|
||||
question = (
|
||||
str(template_id)
|
||||
if template_id
|
||||
else f"Descriptive question {i}"
|
||||
)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"image": item["image"],
|
||||
"question": question,
|
||||
"answer": item[a_key],
|
||||
"qid": (
|
||||
template_id if isinstance(template_id, int) else i
|
||||
),
|
||||
"category": item.get("category", ""),
|
||||
"grading_query": item.get("grading_query", ""),
|
||||
}
|
||||
)
|
||||
elif mode == "reasoning":
|
||||
if "reasoning_q" in item and item.get("reasoning_a"):
|
||||
data.append(
|
||||
{
|
||||
"image": item["image"],
|
||||
"question": item["reasoning_q"],
|
||||
"answer": item["reasoning_a"],
|
||||
"inst_category": item.get("category", 1),
|
||||
"grading_query": item.get("grading_query", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid mode: {mode}. Must be 'descriptive' or 'reasoning'."
|
||||
)
|
||||
|
||||
print(f"Loaded {len(data)} examples from CharXiv ({mode}, {split})")
|
||||
return data
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
if "image" in item and item["image"] is not None:
|
||||
if isinstance(item["image"], Image.Image):
|
||||
return self.encode_image(item["image"])
|
||||
raise ValueError("Could not find image for item")
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
async def _judge_with_gpt(
|
||||
self, question: str, answer: str, prediction: str, grading_query: str = ""
|
||||
) -> Tuple[Optional[str], float]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None, 0.0
|
||||
|
||||
if grading_query:
|
||||
prompt = grading_query.replace("{PREDICTION}", prediction)
|
||||
else:
|
||||
prompt = GRADING_QUERY_TEMPLATE.format(
|
||||
question=question, answer=answer, prediction=prediction
|
||||
)
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content.strip()
|
||||
|
||||
try:
|
||||
result = json.loads(response)
|
||||
if isinstance(result, dict):
|
||||
extract_answer = result.get("extract_answer", "")
|
||||
score = float(result.get("score", 0.0))
|
||||
return extract_answer, score
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r"\{[^}]+\}", response)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
extract_answer = result.get("extract_answer", "")
|
||||
score = float(result.get("score", 0.0))
|
||||
return extract_answer, score
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
return None, 0.0
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT judge error: {e}")
|
||||
return None, 0.0
|
||||
|
||||
def _fallback_score(
|
||||
self, prediction: str, answer: str, mode: str
|
||||
) -> Tuple[str, float]:
|
||||
prediction = prediction.strip().lower()
|
||||
answer = answer.strip().lower()
|
||||
|
||||
if not prediction:
|
||||
return "", 0.0
|
||||
|
||||
if mode == "reasoning":
|
||||
if answer in prediction:
|
||||
return prediction, 1.0
|
||||
try:
|
||||
pred_nums = re.findall(r"-?\d+\.?\d*", prediction)
|
||||
ans_nums = re.findall(r"-?\d+\.?\d*", answer)
|
||||
if pred_nums and ans_nums:
|
||||
for p in pred_nums:
|
||||
for a in ans_nums:
|
||||
if abs(float(p) - float(a)) < 0.01:
|
||||
return prediction, 1.0
|
||||
except ValueError:
|
||||
pass
|
||||
return prediction, 0.0
|
||||
|
||||
else:
|
||||
pred_words = set(prediction.split())
|
||||
ans_words = set(answer.split())
|
||||
if not ans_words:
|
||||
return prediction, 0.0
|
||||
overlap = len(pred_words & ans_words) / len(ans_words)
|
||||
return prediction, min(overlap, 1.0)
|
||||
|
||||
def get_category(self, item: dict, mode: str) -> str:
|
||||
if mode == "descriptive":
|
||||
qid = item.get("qid", 1)
|
||||
return DESCRIPTIVE_CATEGORIES.get(qid, "Information Extraction")
|
||||
else:
|
||||
inst_category = item.get("inst_category", 1)
|
||||
return REASONING_CATEGORIES.get(inst_category, "Text-in-Chart")
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
mode = getattr(self, "mode", "descriptive")
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"}
|
||||
|
||||
use_gpt_judge = getattr(self, "use_gpt_judge", True)
|
||||
grading_query = data_item.get("grading_query", "")
|
||||
answer = data_item.get("answer", "")
|
||||
question = data_item.get("question", "")
|
||||
|
||||
if use_gpt_judge:
|
||||
extracted, score = await self._judge_with_gpt(
|
||||
question, answer, response, grading_query
|
||||
)
|
||||
evaluation_method = "gpt_judge"
|
||||
else:
|
||||
extracted, score = self._fallback_score(response, answer, mode)
|
||||
evaluation_method = "fallback"
|
||||
|
||||
if extracted is None:
|
||||
extracted, score = self._fallback_score(response, answer, mode)
|
||||
evaluation_method = "fallback"
|
||||
|
||||
category = self.get_category(data_item, mode)
|
||||
|
||||
sample = {
|
||||
"question": data_item.get("question", ""),
|
||||
"answer": answer,
|
||||
"prediction": response[:500],
|
||||
"extract_answer": extracted,
|
||||
"score": score,
|
||||
"category": category,
|
||||
"mode": mode,
|
||||
"qid": data_item.get("qid", data_item.get("inst_category", "")),
|
||||
"evaluation_method": evaluation_method,
|
||||
}
|
||||
|
||||
return {"accuracy": score, "score": score}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0, "score": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
def compute_category_metrics(samples: List[dict]) -> Dict:
|
||||
from collections import defaultdict
|
||||
|
||||
scores_by_category = defaultdict(list)
|
||||
|
||||
for sample in samples:
|
||||
if "error" in sample:
|
||||
continue
|
||||
category = sample.get("category", "Unknown")
|
||||
score = sample.get("score", 0.0)
|
||||
scores_by_category[category].append(score)
|
||||
|
||||
metrics = {}
|
||||
total_score = 0.0
|
||||
total_count = 0
|
||||
|
||||
for category, scores in scores_by_category.items():
|
||||
if scores:
|
||||
avg_score = sum(scores) / len(scores)
|
||||
metrics[category] = {
|
||||
"count": len(scores),
|
||||
"average_score": avg_score,
|
||||
}
|
||||
total_score += sum(scores)
|
||||
total_count += len(scores)
|
||||
|
||||
if total_count > 0:
|
||||
metrics["Overall"] = {
|
||||
"count": total_count,
|
||||
"average_score": total_score / total_count,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
CharXiv,
|
||||
mode="descriptive", # or "reasoning"
|
||||
split="validation",
|
||||
use_gpt_judge=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=1024,
|
||||
)
|
||||
)
|
||||
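compute_category_metrics is exported for aggregating the per-item samples after a run. A small usage sketch with made-up samples (assumed invocation, not part of the commit):

# Illustrative sketch only: how compute_category_metrics aggregates judged samples.
from environments.eval_environments.charxiv_environment import compute_category_metrics

samples = [
    {"category": "Counting", "score": 1.0},
    {"category": "Counting", "score": 0.5},
    {"category": "Enumeration", "score": 0.0},
    {"error": "Empty response"},  # skipped by the aggregation
]
print(compute_category_metrics(samples))
# {'Counting': {'count': 2, 'average_score': 0.75},
#  'Enumeration': {'count': 1, 'average_score': 0.0},
#  'Overall': {'count': 3, 'average_score': 0.5}}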
environments/eval_environments/docvqa_environment.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import asyncio
import base64
import io
import re
from typing import List, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class DocVQA(EvalBase):
    QUESTION_TYPES = [
        "figure/diagram",
        "layout",
        "table/list",
        "Image/Photo",
        "handwritten",
        "form",
        "free_text",
        "others",
    ]

    def setup_data(self) -> list:
        # Note: test split has hidden answers (for server evaluation)
        # Use validation for local evaluation
        split = getattr(self, "split", "validation")
        dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        print(f"Loaded {len(dataset)} examples from DocVQA ({split})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> str:
        if "image" in item and item["image"] is not None:
            if isinstance(item["image"], Image.Image):
                return self.encode_image(item["image"])
        raise ValueError(
            f"Could not find image for item {item.get('questionId', 'unknown')}"
        )

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"""Look at the document and answer the question.

Question: {question}

Provide only the answer, as concisely as possible."""

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        response = response.strip()

        patterns = [
            r"answer[:\s]+(.+?)(?:\.|$)",
            r"\"([^\"]+)\"",
        ]

        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                return match.group(1).strip()

        lines = response.split("\n")
        if lines:
            return lines[-1].strip()

        return response

    def normalize_text(self, text: str) -> str:
        text = text.lower().strip()
        text = re.sub(r"[^\w\s]", "", text)
        text = " ".join(text.split())
        return text

    def anls_score(
        self, prediction: str, answers: List[str], threshold: float = 0.5
    ) -> float:
        """
        Calculate Average Normalized Levenshtein Similarity (ANLS).
        This is the standard metric for DocVQA.
        """
        pred_norm = self.normalize_text(prediction)

        if not pred_norm:
            return 0.0

        max_score = 0.0
        for answer in answers:
            ans_norm = self.normalize_text(answer)
            if not ans_norm:
                continue

            if pred_norm == ans_norm:
                max_score = 1.0
                break

            distance = self._levenshtein_distance(pred_norm, ans_norm)
            max_len = max(len(pred_norm), len(ans_norm))
            nls = 1 - distance / max_len if max_len > 0 else 0

            if nls >= threshold:
                max_score = max(max_score, nls)

        return max_score

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""
            if hasattr(message, "reasoning") and message.reasoning and not response:
                response = message.reasoning
            if not response and hasattr(message, "model_extra"):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning

            if not response:
                return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answers = data_item.get("answers", [])
            if isinstance(answers, str):
                answers = [answers]

            anls = self.anls_score(extracted, answers)
            correct = anls >= 0.5

            sample = {
                "questionId": data_item.get("questionId", ""),
                "question": data_item.get("question", ""),
                "answers": answers,
                "prediction": extracted,
                "anls": anls,
                "correct": correct,
                "question_types": data_item.get("question_types", []),
            }

            return {"accuracy": 1.0 if correct else 0.0, "anls": anls}, sample

        except Exception as e:
            return {"accuracy": 0.0, "anls": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            DocVQA,
            split="test",
            temperature=0.0,
            max_tokens=256,
        )
    )
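For reference, one ANLS comparison worked by hand, mirroring DocVQA.anls_score (illustrative sketch, not part of the committed file):

# Prediction "32 Pages." and gold answer "32 page" normalize to "32 pages" / "32 page".
# The Levenshtein distance between them is 1 (one trailing "s") and the longer string
# has 8 characters, so NLS = 1 - 1/8 = 0.875. That clears the 0.5 threshold, so the item
# scores 0.875 on the "anls" metric and counts as correct for the "accuracy" metric.
pred_norm, ans_norm = "32 pages", "32 page"
distance = 1
max_len = max(len(pred_norm), len(ans_norm))
print(1 - distance / max_len)  # 0.875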
environments/eval_environments/eval_base.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
|
||||
Base class for evaluation environments.
|
||||
|
||||
based on PR #290 for eval-only environments.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import jsonlines
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def evaluate_log(
|
||||
metrics: Dict,
|
||||
eval_dir: Optional[str] = None,
|
||||
task_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
generation_parameters: Optional[Dict] = None,
|
||||
samples: Optional[List[Dict]] = None,
|
||||
verbose: bool = True,
|
||||
):
|
||||
if eval_dir is None:
|
||||
logger.warning("eval_dir is not set, skipping evaluation logging")
|
||||
return
|
||||
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
filepath = os.path.join(eval_dir, "metrics.json")
|
||||
|
||||
if start_time is None:
|
||||
start_time = time.time()
|
||||
if end_time is None:
|
||||
end_time = time.time()
|
||||
if generation_parameters is None:
|
||||
generation_parameters = {}
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f" {task_name}")
|
||||
print(f"{'='*60}")
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, float):
|
||||
print(f" {key}: {value:.4f}")
|
||||
else:
|
||||
print(f" {key}: {value}")
|
||||
print(f" Time: {end_time - start_time:.1f}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
task_key = f"atropos|{task_name}|0"
|
||||
eval_result = {
|
||||
"config_general": {
|
||||
"model_name": model_name,
|
||||
"total_evaluation_time_seconds": str(end_time - start_time),
|
||||
"generation_parameters": generation_parameters,
|
||||
},
|
||||
"results": {
|
||||
task_key: metrics,
|
||||
"all": metrics,
|
||||
},
|
||||
}
|
||||
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(eval_result, f, indent=2)
|
||||
|
||||
print(f"Evaluation results saved to {filepath}")
|
||||
|
||||
if samples:
|
||||
samples_filepath = os.path.join(eval_dir, "samples.jsonl")
|
||||
with jsonlines.open(samples_filepath, "w") as writer:
|
||||
for sample in samples:
|
||||
writer.write(sample)
|
||||
print(f"Evaluation samples saved to {samples_filepath}")
|
||||
|
||||
|
||||
class EvalBase(ABC):
|
||||
"""
|
||||
Base class for evaluation environments.
|
||||
|
||||
Subclasses must implement:
|
||||
- setup_data(): Returns list of data items to evaluate
|
||||
- run_item(client, data_item): Process one item, returns (metrics_dict, sample)
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
self.data = self.setup_data()
|
||||
|
||||
def get_generation_params(self) -> dict:
|
||||
return {
|
||||
"temperature": getattr(self, "temperature", 0.0),
|
||||
"max_tokens": getattr(self, "max_tokens", 4096),
|
||||
"n": getattr(self, "n", 1),
|
||||
}
|
||||
|
||||
async def chat_completion(
|
||||
self, client: AsyncOpenAI, messages: List[dict]
|
||||
) -> ChatCompletion:
|
||||
gen_params = self.get_generation_params()
|
||||
return await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def setup_data(self) -> list:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
"""
|
||||
Process a single data item.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, dict]: (metrics_dict, sample_dict)
|
||||
- metrics_dict: keys like "accuracy" with numeric values
|
||||
- sample_dict: the sample data for logging
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def __call__(self, client: AsyncOpenAI):
|
||||
start_time = time.time()
|
||||
|
||||
task_coros = [self.run_item(client, item) for item in self.data]
|
||||
task_results = await tqdm_asyncio.gather(
|
||||
*task_coros, desc=f"Evaluating {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
metrics_list = [result[0] for result in task_results]
|
||||
samples = [result[1] for result in task_results]
|
||||
|
||||
keys = list(metrics_list[0].keys())
|
||||
metrics = {
|
||||
key: sum(result[key] for result in metrics_list) / len(metrics_list)
|
||||
for key in keys
|
||||
}
|
||||
|
||||
task_name = self.__class__.__name__
|
||||
|
||||
evaluate_log(
|
||||
metrics,
|
||||
eval_dir=getattr(self, "eval_dir", None),
|
||||
task_name=task_name,
|
||||
model_name=self.model_name,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters=self.get_generation_params(),
|
||||
samples=samples,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
async def eval_runner(eval_cls, **eval_kwargs):
|
||||
"""
|
||||
CLI runner for evaluation environments.
|
||||
|
||||
Usage in __main__:
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
from eval_base import eval_runner
|
||||
asyncio.run(eval_runner(MyEval, temperature=0.0, max_tokens=4096))
|
||||
"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL for OpenAI-compatible API",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default="x",
|
||||
help="API key (use 'x' for local servers)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save evaluation results",
|
||||
)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
client = AsyncOpenAI(
|
||||
base_url=args.base_url,
|
||||
api_key=args.api_key,
|
||||
)
|
||||
|
||||
eval_kwargs["model_name"] = args.model_name
|
||||
eval_kwargs["eval_dir"] = args.eval_dir
|
||||
|
||||
eval_env = eval_cls(**eval_kwargs)
|
||||
return await eval_env(client)
|
||||
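A minimal sketch of what a subclass has to provide, using a hypothetical EchoEval environment (assumed names, not part of the commit): setup_data() supplies the items and run_item() returns a (metrics, sample) pair per item, which __call__ averages and logs via evaluate_log.

# Hypothetical subclass sketch showing the EvalBase contract.
import asyncio
from typing import Tuple

from openai import AsyncOpenAI

from environments.eval_environments.eval_base import EvalBase, eval_runner


class EchoEval(EvalBase):
    def setup_data(self) -> list:
        return [{"question": "2 + 2 = ?", "answer": "4"}]

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        completion = await self.chat_completion(
            client, [{"role": "user", "content": data_item["question"]}]
        )
        prediction = (completion.choices[0].message.content or "").strip()
        correct = prediction == data_item["answer"]
        return {"accuracy": 1.0 if correct else 0.0}, {
            "question": data_item["question"],
            "prediction": prediction,
            "correct": correct,
        }


# Invoked the same way as the committed environments, e.g.:
#   python my_eval.py --model-name my-model --base-url http://localhost:8000/v1
if __name__ == "__main__":
    asyncio.run(eval_runner(EchoEval, temperature=0.0, max_tokens=32))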
environments/eval_environments/infovqa_environment.py (new file, 188 lines)
@@ -0,0 +1,188 @@
import asyncio
import base64
import io
import re
from typing import List, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class InfoVQA(EvalBase):
    def setup_data(self) -> list:
        split = getattr(self, "split", "validation")
        dataset = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split=split)
        print(f"Loaded {len(dataset)} examples from InfoVQA ({split})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> str:
        if "image" in item and item["image"] is not None:
            if isinstance(item["image"], Image.Image):
                return self.encode_image(item["image"])
        raise ValueError(f"Could not find image for item {item.get('id', 'unknown')}")

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"""Look at the infographic and answer the question.

Question: {question}

Provide only the answer, as concisely as possible."""

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        response = response.strip()

        patterns = [
            r"answer[:\s]+(.+?)(?:\.|$)",
            r"\"([^\"]+)\"",
        ]

        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                return match.group(1).strip()

        lines = response.split("\n")
        if lines:
            return lines[-1].strip()

        return response

    def normalize_text(self, text: str) -> str:
        text = text.lower().strip()
        text = re.sub(r"[^\w\s]", "", text)
        text = " ".join(text.split())
        return text

    def anls_score(
        self, prediction: str, answers: List[str], threshold: float = 0.5
    ) -> float:
        """
        Calculate Average Normalized Levenshtein Similarity (ANLS).
        This is the standard metric for InfoVQA.
        """
        pred_norm = self.normalize_text(prediction)

        if not pred_norm:
            return 0.0

        max_score = 0.0
        for answer in answers:
            ans_norm = self.normalize_text(answer)
            if not ans_norm:
                continue

            if pred_norm == ans_norm:
                max_score = 1.0
                break

            distance = self._levenshtein_distance(pred_norm, ans_norm)
            max_len = max(len(pred_norm), len(ans_norm))
            nls = 1 - distance / max_len if max_len > 0 else 0

            if nls >= threshold:
                max_score = max(max_score, nls)

        return max_score

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""
            if hasattr(message, "reasoning") and message.reasoning and not response:
                response = message.reasoning
            if not response and hasattr(message, "model_extra"):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning

            if not response:
                return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answers = data_item.get("answer", [])
            if isinstance(answers, str):
                answers = [answers]

            anls = self.anls_score(extracted, answers)
            correct = anls >= 0.5

            sample = {
                "id": data_item.get("id", ""),
                "question": data_item.get("question", ""),
                "answers": answers,
                "prediction": extracted,
                "anls": anls,
                "correct": correct,
            }

            return {"accuracy": 1.0 if correct else 0.0, "anls": anls}, sample

        except Exception as e:
            return {"accuracy": 0.0, "anls": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            InfoVQA,
            split="test",
            temperature=0.0,
            max_tokens=256,
        )
    )
environments/eval_environments/logicvista_environment.py (new file, 306 lines)
@@ -0,0 +1,306 @@
"""LogicVista evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
EXTRACTION_PROMPT_TEMPLATE = """You are a information extractor that extracts multiple choice letter answer choices \
|
||||
from a paragraph that contains the answer choice and sometimes explaination of why that \
|
||||
choice is correct to the given question.
|
||||
What letter did the following answer choose? If the answer did not select a letter answer choice, \
|
||||
first try to infer the answer based off the given choices.
|
||||
If it does not correspond to an answer choice OR there is no selected answer, respond with Z.
|
||||
Make sure you answer with ONLY the letters chosen.
|
||||
Example 1:
|
||||
Question: <start>
|
||||
What is the main object in image?
|
||||
Options: A. teddy bear B. rabbit C. cat D. dog
|
||||
<end>
|
||||
Answer: <start>
|
||||
a cute teddy bear
|
||||
<end>
|
||||
Your output: A
|
||||
Example 2:
|
||||
Question: <start>
|
||||
What is the main object in image?
|
||||
Options: A. teddy bear B. rabbit C. cat D. dog
|
||||
<end>
|
||||
Answer: <start>
|
||||
Spider
|
||||
<end>
|
||||
Your output: Z
|
||||
Example 3:
|
||||
Question: <start>
|
||||
Which figure is a rotation of the object?
|
||||
<end>
|
||||
Answer: <start>
|
||||
The figure on the right, labeled "D," is a rotation of the object shown in the top left corner.
|
||||
<end>
|
||||
Your output: D
|
||||
Example 4:
|
||||
Question: <start>
|
||||
Which of the boxes comes next in the sequence? Select from A-E
|
||||
<end>
|
||||
Answer: <start>
|
||||
The sequence of the boxes is A, B, C, D, E.
|
||||
<end>
|
||||
Your output: ABCDE
|
||||
Example 5:
|
||||
Question: <start>
|
||||
{question}
|
||||
<end>
|
||||
Answer: <start>
|
||||
{prediction}
|
||||
<end>
|
||||
Your output: """
|
||||
|
||||
|
||||
class LogicVista(EvalBase):
|
||||
SKILL_CATEGORIES = [
|
||||
"inductive",
|
||||
"deductive",
|
||||
"numerical",
|
||||
"spatial",
|
||||
"mechanical",
|
||||
]
|
||||
|
||||
CAPABILITY_CATEGORIES = [
|
||||
"diagram",
|
||||
"ocr",
|
||||
"patterns",
|
||||
"graphs",
|
||||
"tables",
|
||||
"3d shapes",
|
||||
"puzzles",
|
||||
"sequences",
|
||||
"physics",
|
||||
]
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "test")
|
||||
dataset = load_dataset("lscpku/LogicVista", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from LogicVista ({split})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
if "image" in item and item["image"] is not None:
|
||||
if isinstance(item["image"], Image.Image):
|
||||
return self.encode_image(item["image"])
|
||||
raise ValueError(
|
||||
f"Could not find image for item {item.get('question_id', 'unknown')}"
|
||||
)
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
prompt = f"""{question}
|
||||
|
||||
Provide your answer as the letter(s) of the correct choice(s), e.g., A, B, C, D, or multiple letters if applicable."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
prompt = EXTRACTION_PROMPT_TEMPLATE.format(
|
||||
question=question, prediction=response
|
||||
)
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=128,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
|
||||
if result and result.isupper() and result.isalpha():
|
||||
return result
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT extraction error: {e}")
|
||||
return None
|
||||
|
||||
def extract_answer(self, response: str) -> str:
|
||||
response = response.strip().upper()
|
||||
|
||||
letters_with_sep = re.findall(r"[A-E](?:\s*[,\s]\s*[A-E])*", response)
|
||||
if letters_with_sep:
|
||||
letters = re.findall(r"[A-E]", letters_with_sep[-1])
|
||||
return "".join(sorted(set(letters)))
|
||||
|
||||
letters = re.findall(
|
||||
r"[A-E]", response[-20:] if len(response) > 20 else response
|
||||
)
|
||||
if letters:
|
||||
return "".join(sorted(set(letters)))
|
||||
|
||||
all_letters = re.findall(r"[A-E]", response)
|
||||
if all_letters:
|
||||
return "".join(sorted(set(all_letters[-4:])))
|
||||
|
||||
return ""
|
||||
|
||||
def score(self, prediction: str, answer: str) -> bool:
|
||||
if not prediction:
|
||||
return False
|
||||
|
||||
answer_letters = re.findall(r"[A-Ea-e]", answer)
|
||||
answer_normalized = "".join(sorted(set(c.lower() for c in answer_letters)))
|
||||
|
||||
pred_letters = [c.lower() for c in prediction if c.isalpha()]
|
||||
pred_normalized = "".join(sorted(set(pred_letters)))
|
||||
|
||||
return pred_normalized == answer_normalized
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}
|
||||
|
||||
use_gpt_extraction = getattr(self, "use_gpt_extraction", True)
|
||||
extracted = None
|
||||
extraction_method = "regex"
|
||||
|
||||
if use_gpt_extraction:
|
||||
question = data_item.get("question", "")
|
||||
gpt_result = await self._extract_with_gpt(question, response)
|
||||
if gpt_result and gpt_result != "Z":
|
||||
extracted = gpt_result
|
||||
extraction_method = "gpt"
|
||||
|
||||
if not extracted:
|
||||
extracted = self.extract_answer(response)
|
||||
extraction_method = "regex"
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
correct = self.score(extracted, answer)
|
||||
|
||||
sample = {
|
||||
"question_id": data_item.get("question_id", data_item.get("index", "")),
|
||||
"question": data_item.get("question", ""),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"hit": 1 if correct else 0,
|
||||
"correct": correct,
|
||||
"skill": data_item.get("skill", ""),
|
||||
"extraction_method": extraction_method,
|
||||
}
|
||||
|
||||
return {
|
||||
"accuracy": 1.0 if correct else 0.0,
|
||||
"hit": 1 if correct else 0,
|
||||
}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": str(e)}
|
||||
|
||||
|
||||
def compute_skill_metrics(samples: List[dict]) -> Dict:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(samples)
|
||||
|
||||
if "hit" not in df.columns or "skill" not in df.columns:
|
||||
return {"overall_accuracy": df.get("hit", pd.Series([0])).mean()}
|
||||
|
||||
metrics = {}
|
||||
|
||||
# Overall accuracy
|
||||
metrics["Overall"] = {
|
||||
"total": len(df),
|
||||
"correct": int(df["hit"].sum()),
|
||||
"accuracy": float(df["hit"].mean() * 100),
|
||||
}
|
||||
|
||||
# By skill category
|
||||
skill_keywords = ["inductive", "deductive", "numerical", "spatial", "mechanical"]
|
||||
|
||||
for skill in skill_keywords:
|
||||
skill_df = df[df["skill"].str.contains(skill, case=False, na=False)]
|
||||
if len(skill_df) > 0:
|
||||
metrics[skill] = {
|
||||
"total": len(skill_df),
|
||||
"correct": int(skill_df["hit"].sum()),
|
||||
"accuracy": float(skill_df["hit"].mean() * 100),
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
LogicVista,
|
||||
split="test",
|
||||
use_gpt_extraction=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
)
|
||||
)
|
||||
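The set-based letter matching in LogicVista.score is worth a concrete case for multi-letter answers; a standalone restatement with two hypothetical inputs (not part of the committed file):

# Illustrative sketch only: multi-letter answers are compared as case-insensitive
# sets of option letters, so "DB" matches a gold answer written as "B, D".
import re


def letters_match(prediction: str, answer: str) -> bool:
    answer_letters = re.findall(r"[A-Ea-e]", answer)
    answer_normalized = "".join(sorted(set(c.lower() for c in answer_letters)))
    pred_letters = [c.lower() for c in prediction if c.isalpha()]
    pred_normalized = "".join(sorted(set(pred_letters)))
    return pred_normalized == answer_normalized


print(letters_match("DB", "B, D"))  # True: order and case do not matter
print(letters_match("B", "B, D"))   # False: partial selections do not count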
environments/eval_environments/mathverse_environment.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""MathVerse evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
EXTRACT_ICL_EXAMPLES = [
|
||||
"1.\nModel response: 'The perimeter of the sector is approximately (-2, 1)'\n"
|
||||
"Extracted Answer: (-2, 1)\n",
|
||||
"2.\nModel response: 'The correct option is D. They give the solutions to $f(t)=g(t)$.'\n"
|
||||
"Extracted Answer: D\n",
|
||||
"3.\nModel response: 'The range is (-4, 1]. Domain: (-3, 3], Range: (-4, 1]'\n"
|
||||
"Extracted Answer: Domain: (-3, 3], Range: (-4, 1]\n",
|
||||
"4.\nModel response: 'I cannot provide the answer because there is not enough information.'\n"
|
||||
"Extracted Answer: null\n",
|
||||
"5.\nModel response: 'The distance d between Ned and Bart is approximately 22.3 meters.'\n"
|
||||
"Extracted answer: 22.3\n",
|
||||
"6.\nModel response: 'The equation for f is f(x) = -x^2 - 2x + 1'\n"
|
||||
"Extracted answer: f(x) = -x^2 - 2x + 1\n",
|
||||
]
|
||||
|
||||
SCORE_ICL_EXAMPLES = [
|
||||
"""[Question]: Write the set of numbers represented on the number line in interval notation.
|
||||
[Standard Answer]: (-2,1]
|
||||
[Model_answer] : Extracted Answer: \\((-2, 1)\\)
|
||||
Judgement: 0
|
||||
""",
|
||||
"""[Question]: As shown in the figure, circle O has a radius 1.0, if angle BAC = 60.0, then the length of BC is ()
|
||||
Choices:
|
||||
A:2
|
||||
B:2√{3}
|
||||
C:√{3}
|
||||
D:2√{2}
|
||||
[Standard Answer]: C
|
||||
[Model_answer] : B:2√{3}
|
||||
Judgement: 0
|
||||
""",
|
||||
"""[Question]: Find the domain and range of the function f using interval notation.
|
||||
[Standard Answer]: domain: [-4, 0) and range: (-3, 1]
|
||||
[Model_answer] : Range: \\((-4, 1]\\)
|
||||
Judgement: 0
|
||||
""",
|
||||
"""[Question]: As shown in the figure, circle O has a radius 1.0, if angle BAC = 60.0, then the length of BC is ()
|
||||
Choices:
|
||||
A:2
|
||||
B:2√{3}
|
||||
C:√{3}
|
||||
D:2√{2}
|
||||
[Standard Answer]: C
|
||||
[Model_answer] : null
|
||||
Judgement: 0
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
class MathVerse(EvalBase):
|
||||
PROBLEM_VERSIONS = [
|
||||
"Text Dominant",
|
||||
"Text Lite",
|
||||
"Vision Intensive",
|
||||
"Vision Dominant",
|
||||
"Vision Only",
|
||||
]
|
||||
|
||||
def setup_data(self) -> list:
|
||||
config = getattr(self, "config", "testmini")
|
||||
dataset = load_dataset("AI4Math/MathVerse", config, split="testmini")
|
||||
print(f"Loaded {len(dataset)} examples from MathVerse ({config})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
if "image" in item and item["image"] is not None:
|
||||
if isinstance(item["image"], Image.Image):
|
||||
return self.encode_image(item["image"])
|
||||
raise ValueError(
|
||||
f"Could not find image for item {item.get('sample_index', 'unknown')}"
|
||||
)
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
|
||||
use_cot = getattr(self, "use_cot", False)
|
||||
if use_cot and "query_cot" in item:
|
||||
question = item["query_cot"]
|
||||
elif "question_for_eval" in item:
|
||||
question = item["question_for_eval"]
|
||||
else:
|
||||
question = item.get("question", "")
|
||||
|
||||
prompt = f"""{question}
|
||||
|
||||
Please solve the problem step by step and provide your final answer."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
async def _extract_with_gpt(self, prediction: str) -> Optional[str]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
task_description = (
|
||||
"I am providing you a response from a model to a math problem, "
|
||||
"termed 'Model Response'. You should extract the answer from the "
|
||||
"response as 'Extracted Answer'. Directly output the extracted "
|
||||
"answer with no explanation.\n\n"
|
||||
)
|
||||
prompt = task_description
|
||||
for example in EXTRACT_ICL_EXAMPLES:
|
||||
prompt += example + "\n\n"
|
||||
prompt += f"7.\nModel response: '{prediction}'\nExtracted Answer: "
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
return result if result else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT extraction error: {e}")
|
||||
return None
|
||||
|
||||
async def _score_with_gpt(
|
||||
self, question: str, answer: str, extracted: str
|
||||
) -> Optional[bool]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None
|
||||
|
||||
if str(extracted).strip() == str(answer).strip():
|
||||
return True
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
task_description = (
|
||||
"Below are two answers to a math question. Question is [Question], "
|
||||
"[Standard Answer] is the standard answer to the question, and "
|
||||
"[Model_answer] is the answer extracted from a model's output to "
|
||||
"this question. Determine whether these two answers are consistent.\n"
|
||||
"Please note that only when the [Model_answer] completely matches "
|
||||
"the [Standard Answer] means they are consistent. For non-MCQ "
|
||||
"questions, if the meaning is expressed in the same way, it is also "
|
||||
"considered consistent, for example, 0.5m and 50cm.\n"
|
||||
"If they are consistent, Judgement is 1; if different, Judgement is 0.\n\n"
|
||||
)
|
||||
prompt = task_description
|
||||
for example in SCORE_ICL_EXAMPLES:
|
||||
prompt += example + "\n\n"
|
||||
prompt += f"""[Question]: {question}
|
||||
[Standard Answer]: {answer}
|
||||
[Model_answer] : {extracted}
|
||||
Judgement:"""
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=16,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
if result in ["0", "1"]:
|
||||
return int(result) == 1
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT scoring error: {e}")
|
||||
return None
|
||||
|
||||
def extract_answer_fallback(self, response: str) -> str:
|
||||
response = response.strip().upper()
|
||||
|
||||
for char in reversed(response):
|
||||
if char in "ABCDE":
|
||||
return char
|
||||
|
||||
numbers = re.findall(r"-?\d+\.?\d*", response)
|
||||
if numbers:
|
||||
return numbers[-1]
|
||||
|
||||
return response[:100]
|
||||
|
||||
def score_fallback(self, prediction: str, answer: str) -> bool:
|
||||
pred = prediction.strip().upper()
|
||||
ans = answer.strip().upper()
|
||||
|
||||
if pred == ans:
|
||||
return True
|
||||
|
||||
try:
|
||||
pred_num = float(pred)
|
||||
ans_num = float(ans)
|
||||
return abs(pred_num - ans_num) < 0.01
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
use_gpt_evaluation = getattr(self, "use_gpt_evaluation", True)
|
||||
answer = data_item.get("answer", "")
|
||||
question = data_item.get("question_for_eval", data_item.get("question", ""))
|
||||
|
||||
if use_gpt_evaluation:
|
||||
extracted = await self._extract_with_gpt(response)
|
||||
if not extracted:
|
||||
extracted = self.extract_answer_fallback(response)
|
||||
extraction_method = "fallback"
|
||||
else:
|
||||
extraction_method = "gpt"
|
||||
else:
|
||||
extracted = self.extract_answer_fallback(response)
|
||||
extraction_method = "fallback"
|
||||
|
||||
if use_gpt_evaluation:
|
||||
score_result = await self._score_with_gpt(question, answer, extracted)
|
||||
if score_result is not None:
|
||||
correct = score_result
|
||||
scoring_method = "gpt"
|
||||
else:
|
||||
correct = self.score_fallback(extracted, answer)
|
||||
scoring_method = "fallback"
|
||||
else:
|
||||
correct = self.score_fallback(extracted, answer)
|
||||
scoring_method = "fallback"
|
||||
|
||||
metadata = data_item.get("metadata", {})
|
||||
sample = {
|
||||
"sample_index": data_item.get("sample_index", ""),
|
||||
"problem_index": data_item.get("problem_index", ""),
|
||||
"problem_version": data_item.get("problem_version", ""),
|
||||
"question": question[:200],
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"subject": (
|
||||
metadata.get("subject", "") if isinstance(metadata, dict) else ""
|
||||
),
|
||||
"subfield": (
|
||||
metadata.get("subfield", "") if isinstance(metadata, dict) else ""
|
||||
),
|
||||
"extraction_method": extraction_method,
|
||||
"scoring_method": scoring_method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MathVerse,
|
||||
split="testmini",
|
||||
use_cot=False,
|
||||
use_gpt_evaluation=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=2048,
|
||||
)
|
||||
)
|
||||
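The two judge calls above follow an extract-then-score pattern: one request pulls an "Extracted Answer" out of the model response, a second request judges consistency against the standard answer. The sketch below exercises just the scoring half in isolation; it is a hypothetical, minimal example (the helper name, the sample strings, and the reliance on an OPENAI_API_KEY environment variable are assumptions for illustration, not part of this diff).

# Hypothetical standalone sketch of the consistency-judging step used above.
import asyncio
import os

import openai


async def judge_consistency(question: str, answer: str, extracted: str) -> bool:
    client = openai.AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    prompt = (
        "Below are two answers to a math question. Determine whether they are "
        "consistent (for example, 0.5m and 50cm count as consistent).\n"
        "If they are consistent, Judgement is 1; if different, Judgement is 0.\n\n"
        f"[Question]: {question}\n[Standard Answer]: {answer}\n"
        f"[Model_answer] : {extracted}\nJudgement:"
    )
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=16,
    )
    return completion.choices[0].message.content.strip() == "1"


if __name__ == "__main__":
    # Expected to print True: the judge is asked to treat 0.5m and 50cm as equivalent.
    print(asyncio.run(judge_consistency("What is 1/2 m in cm?", "50cm", "0.5m")))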
345 environments/eval_environments/mathvision_environment.py Normal file
@@ -0,0 +1,345 @@
"""MathVision evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
ICL_EXAMPLES = [
|
||||
"""Hint: Please answer the question and provide the final answer at the end.
|
||||
Question: Which number is missing?
|
||||
Model response: The number missing in the sequence is 14.
|
||||
Extracted answer: 14
|
||||
""",
|
||||
"Hint: Please answer the question and provide the final answer at the end.\n"
|
||||
"Question: What is the fraction of females facing the camera?\n"
|
||||
"Model response: The fraction of females facing the camera is 0.6.\n"
|
||||
"Extracted answer: 0.6\n",
|
||||
"""Hint: Please answer the question and provide the final answer at the end.
|
||||
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)
|
||||
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.
|
||||
Extracted answer: 1.45
|
||||
""",
|
||||
"""Hint: Please answer the question and provide the final answer at the end.
|
||||
Question: Between which two years does the line graph saw its maximum peak?
|
||||
Model response: The line graph saw its maximum peak between 2007 and 2008.
|
||||
Extracted answer: [2007, 2008]
|
||||
""",
|
||||
"""Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
|
||||
Question: What fraction of the shape is blue?
|
||||
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5
|
||||
Model response: The correct answer is (B) 8/11.
|
||||
Extracted answer: B
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def can_infer_option(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
if "Failed to obtain answer via API" in answer:
|
||||
return None
|
||||
|
||||
answer_mod = answer
|
||||
for c in ".()[],:;!*#{}":
|
||||
answer_mod = answer_mod.replace(c, " ")
|
||||
|
||||
splits = [x.strip() for x in answer_mod.split()]
|
||||
count = sum(1 for ch in choices if ch in splits)
|
||||
|
||||
if count == 1:
|
||||
for ch in choices:
|
||||
if "A" in splits and len(splits) > 3:
|
||||
continue
|
||||
if ch in splits and splits.index(ch) > (len(splits) - 5):
|
||||
return ch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def can_infer_text(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
answer_lower = answer.lower()
|
||||
|
||||
if len(answer_lower) > 2 * sum(len(str(v)) for v in choices.values()):
|
||||
return None
|
||||
|
||||
cands = []
|
||||
for k, v in choices.items():
|
||||
if str(v).lower() in answer_lower:
|
||||
cands.append(k)
|
||||
|
||||
if len(cands) == 1:
|
||||
return cands[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def can_infer(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
answer = str(answer)
|
||||
result = can_infer_option(answer, choices)
|
||||
if result:
|
||||
return result
|
||||
return can_infer_text(answer, choices)
|
||||
|
||||
|
||||
def is_equal(asw: str, gt_asw: str) -> bool:
|
||||
asw = str(asw).lower().strip()
|
||||
gt_asw = str(gt_asw).lower().strip()
|
||||
|
||||
if gt_asw == asw:
|
||||
return True
|
||||
|
||||
try:
|
||||
a = eval(gt_asw)
|
||||
b = eval(asw)
|
||||
if abs(float(a) - float(b)) < 1e-6:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from latex2sympy2 import latex2sympy
|
||||
|
||||
a = latex2sympy(gt_asw)
|
||||
b = latex2sympy(asw)
|
||||
if abs(eval(str(a)) - eval(str(b))) < 1e-6:
|
||||
return True
|
||||
if abs(float(a) - float(b)) < 1e-6:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MathVision(EvalBase):
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "testmini")
|
||||
try:
|
||||
dataset = load_dataset("MathLLMs/MathVision", split=split)
|
||||
except Exception:
|
||||
dataset = load_dataset("MathLLMs/MathVision", "default", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MathVision ({split})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
for key in ["decoded_image", "image"]:
|
||||
if key in item and item[key] is not None:
|
||||
if isinstance(item[key], Image.Image):
|
||||
return self.encode_image(item[key])
|
||||
raise ValueError(f"Could not find image for item {item.get('id', 'unknown')}")
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
choices = item.get("choices", [])
|
||||
|
||||
if choices:
|
||||
try:
|
||||
if isinstance(choices, str):
|
||||
choices = eval(choices)
|
||||
choices_text = "\n".join(
|
||||
[f"({chr(65+i)}) {c}" for i, c in enumerate(choices)]
|
||||
)
|
||||
hint = "Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end."
|
||||
prompt = f"Hint: {hint}\nQuestion: {question}\nChoices:\n{choices_text}"
|
||||
except Exception:
|
||||
hint = "Please answer the question and provide the final answer at the end."
|
||||
prompt = f"Hint: {hint}\nQuestion: {question}"
|
||||
else:
|
||||
hint = "Please answer the question and provide the final answer at the end."
|
||||
prompt = f"Hint: {hint}\nQuestion: {question}"
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def _prefetch_answer(self, response: str, item: dict) -> Tuple[Optional[str], bool]:
|
||||
choices = item.get("choices", [])
|
||||
|
||||
if choices:
|
||||
try:
|
||||
if isinstance(choices, str):
|
||||
choices = eval(choices)
|
||||
if len(choices) > 0:
|
||||
choices_dict = {chr(65 + i): val for i, val in enumerate(choices)}
|
||||
result = can_infer(response, choices_dict)
|
||||
if result:
|
||||
return result, True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None, False
|
||||
|
||||
async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
task_description = """Please read the following example.
|
||||
Then extract the answer from the model response and type it at the end of the prompt.
|
||||
|
||||
"""
|
||||
prompt = task_description
|
||||
for example in ICL_EXAMPLES:
|
||||
prompt += example + "\n"
|
||||
prompt += question + "\n"
|
||||
prompt += f"Model response: {response}\n"
|
||||
prompt += "Extracted answer:"
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=128,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
return result if result else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT extraction error: {e}")
|
||||
return None
|
||||
|
||||
def extract_answer_fallback(self, response: str) -> str:
|
||||
response = response.strip()
|
||||
|
||||
for char in reversed(response.upper()):
|
||||
if char in "ABCDEFGH":
|
||||
return char
|
||||
|
||||
numbers = re.findall(r"-?\d+\.?\d*", response)
|
||||
if numbers:
|
||||
return numbers[-1]
|
||||
|
||||
return response[:100]
|
||||
|
||||
def score(self, prediction: str, answer: str, item: dict) -> bool:
|
||||
choices = item.get("choices", [])
|
||||
|
||||
if choices:
|
||||
try:
|
||||
if isinstance(choices, str):
|
||||
choices = eval(choices)
|
||||
if len(choices) > 0:
|
||||
choices_dict = {chr(65 + i): val for i, val in enumerate(choices)}
|
||||
result = can_infer(prediction, choices_dict)
|
||||
if result:
|
||||
return result.upper() == answer.upper()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return is_equal(prediction, answer)
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
use_gpt_extraction = getattr(self, "use_gpt_extraction", True)
|
||||
answer = data_item.get("answer", "")
|
||||
|
||||
prefetch_result, prefetch_success = self._prefetch_answer(
|
||||
response, data_item
|
||||
)
|
||||
|
||||
if prefetch_success and prefetch_result:
|
||||
extracted = prefetch_result
|
||||
extraction_method = "prefetch"
|
||||
elif use_gpt_extraction:
|
||||
question = data_item.get("question", "")
|
||||
gpt_result = await self._extract_with_gpt(question, response)
|
||||
if gpt_result:
|
||||
extracted = gpt_result
|
||||
extraction_method = "gpt"
|
||||
else:
|
||||
extracted = self.extract_answer_fallback(response)
|
||||
extraction_method = "fallback"
|
||||
else:
|
||||
extracted = self.extract_answer_fallback(response)
|
||||
extraction_method = "fallback"
|
||||
|
||||
correct = self.score(extracted, answer, data_item)
|
||||
|
||||
sample = {
|
||||
"id": data_item.get("id", data_item.get("index", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"category": data_item.get("category", ""),
|
||||
"extraction_method": extraction_method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MathVision,
|
||||
split="testmini",
|
||||
use_gpt_extraction=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=2048,
|
||||
)
|
||||
)
|
||||
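The module-level helpers above do the cheap, judge-free part of scoring: can_infer first looks for a bare option letter, then for a choice matched by its text, and is_equal accepts numerically equivalent free-form answers. A small illustrative check is sketched below; the import path mirrors this repo's layout and the sample strings are made up, so treat it as a sanity sketch rather than a shipped test.

# Illustrative behaviour of the choice-inference and equality helpers (hypothetical inputs).
from environments.eval_environments.mathvision_environment import can_infer, is_equal

choices = {"A": "3/11", "B": "8/11", "C": "6/11", "D": "3/5"}
assert can_infer("The correct answer is (B).", choices) == "B"  # explicit letter near the end
assert can_infer("I think it is 8/11.", choices) == "B"         # matched via the option text
assert is_equal("0.5", "1/2") is True                            # numeric equivalence within 1e-6
print("choice-inference sketch passed")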
411 environments/eval_environments/mathvista_environment.py Normal file
@@ -0,0 +1,411 @@
"""MathVista evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
ICL_EXAMPLES = [
|
||||
"""
|
||||
Hint: Please answer the question requiring an integer answer and provide the final value,
|
||||
e.g., 1, 2, 3, at the end.
|
||||
Question: Which number is missing?
|
||||
Model response: The number missing in the sequence is 14.
|
||||
Extracted answer: 14
|
||||
""",
|
||||
"""
|
||||
Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value,
|
||||
e.g., 1.2, 1.3, 1.4, at the end.
|
||||
Question: What is the fraction of females facing the camera?
|
||||
Model response: The fraction of females facing the camera is 0.6,
|
||||
which means that six out of ten females in the group are facing the camera.
|
||||
Extracted answer: 0.6
|
||||
""",
|
||||
"""
|
||||
Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value,
|
||||
e.g., 1.23, 1.34, 1.45, at the end.
|
||||
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)
|
||||
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.
|
||||
Extracted answer: 1.45
|
||||
""",
|
||||
"""
|
||||
Hint: Please answer the question requiring a Python list as an answer and provide the final list,
|
||||
e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
|
||||
Question: Between which two years does the line graph saw its maximum peak?
|
||||
Model response: The line graph saw its maximum peak between 2007 and 2008.
|
||||
Extracted answer: [2007, 2008]
|
||||
""",
|
||||
"""
|
||||
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
|
||||
Question: What fraction of the shape is blue?
|
||||
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5
|
||||
Model response: The correct answer is (B) 8/11.
|
||||
Extracted answer: B
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def build_extraction_prompt(question: str, prediction: str) -> str:
|
||||
task_description = """Please read the following example.
|
||||
Then extract the answer from the model response and type it at the end of the prompt.
|
||||
"""
|
||||
prompt = task_description
|
||||
for example in ICL_EXAMPLES:
|
||||
prompt += example + "\n"
|
||||
prompt += question + "\n"
|
||||
prompt += "Model response: " + prediction + "\n"
|
||||
prompt += "Extracted answer:"
|
||||
return prompt
|
||||
|
||||
|
||||
def can_infer_option(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
if "Failed to obtain answer via API" in answer:
|
||||
return None
|
||||
|
||||
answer_mod = answer
|
||||
for c in ".()[],:;!*#{}":
|
||||
answer_mod = answer_mod.replace(c, " ")
|
||||
|
||||
splits = [x.strip() for x in answer_mod.split()]
|
||||
count = sum(1 for ch in choices if ch in splits)
|
||||
|
||||
if count == 1:
|
||||
for ch in choices:
|
||||
if "A" in splits and len(splits) > 3:
|
||||
continue
|
||||
if ch in splits and splits.index(ch) > (len(splits) - 5):
|
||||
return ch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def can_infer_text(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
answer_lower = answer.lower()
|
||||
|
||||
if len(answer_lower) > 2 * sum(len(str(v)) for v in choices.values()):
|
||||
return None
|
||||
|
||||
cands = []
|
||||
for k, v in choices.items():
|
||||
if str(v).lower() in answer_lower:
|
||||
cands.append(k)
|
||||
|
||||
if len(cands) == 1:
|
||||
return cands[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def can_infer(answer: str, choices: Dict[str, str]) -> Optional[str]:
|
||||
answer = str(answer)
|
||||
result = can_infer_option(answer, choices)
|
||||
if result:
|
||||
return result
|
||||
return can_infer_text(answer, choices)
|
||||
|
||||
|
||||
class MathVista(EvalBase):
|
||||
TASK_TYPES = ["FQA", "GPS", "MWP", "TQA", "VQA"]
|
||||
SKILL_TYPES = ["ALG", "ARI", "GEO", "LOG", "NUM", "SCI", "STA"]
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "testmini")
|
||||
dataset = load_dataset("AI4Math/MathVista", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from MathVista ({split})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
if "decoded_image" in item and item["decoded_image"] is not None:
|
||||
return self.encode_image(item["decoded_image"])
|
||||
if "image" in item and item["image"] is not None:
|
||||
if isinstance(item["image"], Image.Image):
|
||||
return self.encode_image(item["image"])
|
||||
raise ValueError(f"Could not find image for item {item.get('pid', 'unknown')}")
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
image_base64 = self.get_image_base64(item)
|
||||
|
||||
use_query = getattr(self, "use_query", True)
|
||||
if use_query and "query" in item:
|
||||
prompt = item["query"]
|
||||
else:
|
||||
prompt = self._build_custom_prompt(item)
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def _build_custom_prompt(self, item: dict) -> str:
|
||||
question = item.get("question", "")
|
||||
question_type = item.get("question_type", "free_form")
|
||||
answer_type = item.get("answer_type", "text")
|
||||
precision = item.get("precision", 2)
|
||||
|
||||
if question_type == "multi_choice":
|
||||
choices = item.get("choices", [])
|
||||
choices_text = "\n".join(choices) if choices else ""
|
||||
hint = (
|
||||
"Please answer the question and provide the correct option letter, "
|
||||
"e.g., A, B, C, D, at the end."
|
||||
)
|
||||
return f"Hint: {hint}\nQuestion: {question}\nChoices:\n{choices_text}"
|
||||
|
||||
if answer_type == "integer":
|
||||
hint = (
|
||||
"Please answer the question requiring an integer answer "
|
||||
"and provide the final value, e.g., 1, 2, 3, at the end."
|
||||
)
|
||||
elif answer_type == "float":
|
||||
hint = (
|
||||
f"Please answer the question requiring a floating-point number "
|
||||
f"with {precision} decimal place(s) and provide the final value at the end."
|
||||
)
|
||||
elif answer_type == "list":
|
||||
hint = (
|
||||
"Please answer the question requiring a Python list as an answer "
|
||||
"and provide the final list, e.g., [1, 2, 3], at the end."
|
||||
)
|
||||
else:
|
||||
hint = "Please answer the question and provide the final answer at the end."
|
||||
|
||||
return f"Hint: {hint}\nQuestion: {question}"
|
||||
|
||||
def _prefetch_answer(self, response: str, item: dict) -> Tuple[Optional[str], bool]:
|
||||
question_type = item.get("question_type", "free_form")
|
||||
answer_type = item.get("answer_type", "text")
|
||||
|
||||
if question_type == "multi_choice":
|
||||
choices_list = item.get("choices", [])
|
||||
if choices_list:
|
||||
choices = {chr(65 + i): val for i, val in enumerate(choices_list)}
|
||||
result = can_infer(response, choices)
|
||||
if result:
|
||||
return result, True
|
||||
|
||||
# Fallback: find last letter
|
||||
for char in reversed(response.upper()):
|
||||
if char in "ABCDEFGH":
|
||||
return char, True
|
||||
return None, False
|
||||
|
||||
response = response.strip()
|
||||
|
||||
if answer_type == "integer":
|
||||
numbers = re.findall(r"-?\d+", response)
|
||||
if numbers:
|
||||
return numbers[-1], True
|
||||
|
||||
elif answer_type == "float":
|
||||
numbers = re.findall(r"-?\d+\.?\d*", response)
|
||||
if numbers:
|
||||
return numbers[-1], True
|
||||
|
||||
elif answer_type == "list":
|
||||
match = re.search(r"\[[\d\.,\s-]+\]", response)
|
||||
if match:
|
||||
return match.group(0), True
|
||||
|
||||
return None, False
|
||||
|
||||
async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]:
|
||||
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
|
||||
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
|
||||
judge_api_key = os.environ.get(
|
||||
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
|
||||
)
|
||||
|
||||
if not judge_api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
judge_client = openai.AsyncOpenAI(
|
||||
api_key=judge_api_key,
|
||||
base_url=judge_base_url,
|
||||
)
|
||||
|
||||
prompt = build_extraction_prompt(question, response)
|
||||
|
||||
completion = await judge_client.chat.completions.create(
|
||||
model=judge_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.0,
|
||||
max_tokens=128,
|
||||
)
|
||||
|
||||
result = completion.choices[0].message.content.strip()
|
||||
return result if result else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"GPT extraction error: {e}")
|
||||
return None
|
||||
|
||||
def extract_answer(
|
||||
self, response: str, answer_type: str, question_type: str
|
||||
) -> str:
|
||||
response = response.strip()
|
||||
|
||||
if question_type == "multi_choice":
|
||||
for char in reversed(response.upper()):
|
||||
if char in "ABCDEFGH":
|
||||
return char
|
||||
return ""
|
||||
|
||||
if answer_type == "integer":
|
||||
numbers = re.findall(r"-?\d+", response)
|
||||
return numbers[-1] if numbers else ""
|
||||
|
||||
if answer_type == "float":
|
||||
numbers = re.findall(r"-?\d+\.?\d*", response)
|
||||
return numbers[-1] if numbers else ""
|
||||
|
||||
if answer_type == "list":
|
||||
match = re.search(r"\[[\d\.,\s-]+\]", response)
|
||||
return match.group(0) if match else ""
|
||||
|
||||
return response
|
||||
|
||||
def score(
|
||||
self, prediction: str, answer: str, answer_type: str, precision: int = 0
|
||||
) -> bool:
|
||||
pred = prediction.strip()
|
||||
ans = answer.strip()
|
||||
|
||||
if not pred:
|
||||
return False
|
||||
|
||||
if answer_type == "text":
|
||||
return pred.upper() == ans.upper()
|
||||
|
||||
if answer_type == "integer":
|
||||
try:
|
||||
return int(float(pred)) == int(float(ans))
|
||||
except (ValueError, OverflowError):
|
||||
return False
|
||||
|
||||
if answer_type == "float":
|
||||
try:
|
||||
tolerance = 10 ** (-precision) if precision > 0 else 0.01
|
||||
return abs(float(pred) - float(ans)) < tolerance
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if answer_type == "list":
|
||||
try:
|
||||
pred_list = eval(pred)
|
||||
ans_list = eval(ans)
|
||||
return pred_list == ans_list
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return pred.lower() == ans.lower()
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0}, {"error": "Empty response"}
|
||||
|
||||
answer_type = data_item.get("answer_type", "text")
|
||||
question_type = data_item.get("question_type", "free_form")
|
||||
precision = data_item.get("precision", 0)
|
||||
|
||||
use_gpt_extraction = getattr(self, "use_gpt_extraction", True)
|
||||
prefetch_result, prefetch_success = self._prefetch_answer(
|
||||
response, data_item
|
||||
)
|
||||
|
||||
if prefetch_success and prefetch_result:
|
||||
extracted = prefetch_result
|
||||
extraction_method = "prefetch"
|
||||
elif use_gpt_extraction:
|
||||
question = data_item.get("query", data_item.get("question", ""))
|
||||
gpt_result = await self._extract_with_gpt(question, response)
|
||||
if gpt_result:
|
||||
extracted = gpt_result
|
||||
extraction_method = "gpt"
|
||||
else:
|
||||
extracted = self.extract_answer(
|
||||
response, answer_type, question_type
|
||||
)
|
||||
extraction_method = "regex_fallback"
|
||||
else:
|
||||
extracted = self.extract_answer(response, answer_type, question_type)
|
||||
extraction_method = "regex"
|
||||
|
||||
answer = data_item.get("answer", "")
|
||||
correct = self.score(extracted, answer, answer_type, precision)
|
||||
|
||||
sample = {
|
||||
"pid": data_item.get("pid", ""),
|
||||
"question": data_item.get("question", ""),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
"correct": correct,
|
||||
"question_type": question_type,
|
||||
"answer_type": answer_type,
|
||||
"extraction_method": extraction_method,
|
||||
}
|
||||
|
||||
return {"accuracy": 1.0 if correct else 0.0}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0}, {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
MathVista,
|
||||
split="testmini",
|
||||
use_query=True,
|
||||
use_gpt_extraction=True,
|
||||
judge_model="gpt-4o-mini",
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
)
|
||||
)
|
||||
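The extraction and scoring methods above key everything off the dataset's answer_type and question_type fields, so they can be sanity-checked without any model call. A small sketch follows; the inputs are invented, and the instance is created with __new__ purely because these particular methods use no instance state (EvalBase construction is skipped on purpose, which is an assumption of this sketch).

# Hypothetical, judge-free check of the answer_type-aware extraction and scoring.
from environments.eval_environments.mathvista_environment import MathVista

env = MathVista.__new__(MathVista)  # bypass EvalBase __init__; methods below are stateless
print(env.extract_answer("The total cost is $1.45.", "float", "free_form"))  # -> "1.45"
print(env.score("1.45", "1.45", "float", precision=2))                       # -> True
print(env.extract_answer("The answer is (C).", "text", "multi_choice"))      # -> "C"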
132 environments/eval_environments/realworldqa_environment.py Normal file
@@ -0,0 +1,132 @@
"""RealWorldQA evaluation environment."""

import asyncio
import base64
import io
from typing import List, Tuple

from datasets import load_dataset
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner


class RealWorldQA(EvalBase):
    """RealWorldQA evaluation environment with lenient string matching."""

    def setup_data(self) -> list:
        split = getattr(self, "split", "test")
        dataset = load_dataset("xai-org/RealworldQA", split=split)
        print(f"Loaded {len(dataset)} examples from RealWorldQA ({split})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> str:
        if "image" in item and item["image"] is not None:
            if isinstance(item["image"], Image.Image):
                return self.encode_image(item["image"])
        raise ValueError("Could not find image for item")

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"""{question}

Provide a brief, direct answer."""

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        response = response.strip()
        lines = response.split("\n")
        if lines:
            return lines[0].strip()
        return response

    def score(self, prediction: str, answer: str) -> bool:
        pred = prediction.strip().lower()
        ans = answer.strip().lower()

        if not pred:
            return False

        if pred == ans:
            return True

        if ans in pred or pred in ans:
            return True

        pred_words = set(pred.split())
        ans_words = set(ans.split())
        overlap = pred_words & ans_words
        if len(overlap) >= len(ans_words) * 0.5:
            return True

        return False

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""
            if hasattr(message, "reasoning") and message.reasoning and not response:
                response = message.reasoning
            if not response and hasattr(message, "model_extra"):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answer = data_item.get("answer", "")
            correct = self.score(extracted, answer)

            sample = {
                "question": data_item.get("question", ""),
                "answer": answer,
                "prediction": extracted,
                "correct": correct,
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            RealWorldQA,
            split="test",
            temperature=0.0,
            max_tokens=256,
        )
    )
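The score method above is deliberately lenient: it accepts an exact match, substring containment in either direction, or at least half of the reference answer's words appearing in the prediction. A tiny worked illustration of those last two rules, with invented strings:

# Hypothetical values showing which rule fires in RealWorldQA.score above.
pred = "two cars are parked"
ans = "two cars"
pred_words, ans_words = set(pred.split()), set(ans.split())
overlap = pred_words & ans_words                 # {"two", "cars"}
print(ans in pred)                               # True -> substring rule already accepts
print(len(overlap) >= len(ans_words) * 0.5)      # True -> word-overlap rule would also accept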
194 environments/eval_environments/visulogic_environment.py Normal file
@@ -0,0 +1,194 @@
"""VisuLogic evaluation environment."""

import asyncio
import base64
import io
import json
import zipfile
from pathlib import Path
from typing import List, Tuple

from huggingface_hub import hf_hub_download
from openai import AsyncOpenAI
from PIL import Image

from environments.eval_environments.eval_base import EvalBase, eval_runner

DEFAULT_DATA_DIR = Path.home() / ".cache" / "visulogic_hf"


class VisuLogic(EvalBase):
    """VisuLogic multiple-choice visual reasoning evaluation environment."""

    TAGS = [
        "Quantitative Reasoning",
        "Spatial Reasoning",
        "Positional Reasoning",
        "Attribute Reasoning",
        "Stylistic Reasoning",
        "Other",
    ]

    def _download_data(self, data_dir: Path) -> None:
        jsonl_path = data_dir / "data.jsonl"
        images_dir = data_dir / "images"

        if jsonl_path.exists() and images_dir.exists():
            return

        print(f"Downloading VisuLogic dataset to {data_dir}...")
        data_dir.mkdir(parents=True, exist_ok=True)

        # Download data.jsonl
        hf_hub_download(
            repo_id="VisuLogic/VisuLogic",
            filename="data.jsonl",
            repo_type="dataset",
            local_dir=data_dir,
        )

        # Download and extract images.zip
        images_zip_path = hf_hub_download(
            repo_id="VisuLogic/VisuLogic",
            filename="images.zip",
            repo_type="dataset",
            local_dir=data_dir,
        )

        print("Extracting images...")
        with zipfile.ZipFile(images_zip_path, "r") as zip_ref:
            zip_ref.extractall(data_dir)

        print("Download complete!")

    def setup_data(self) -> list:
        """
        Load and return dataset as a list.

        Auto-downloads the VisuLogic dataset if data_path is not specified
        or doesn't exist.
        """
        data_path = getattr(self, "data_path", None)

        if data_path is None:
            data_dir = DEFAULT_DATA_DIR
            self._download_data(data_dir)
            jsonl_path = data_dir / "data.jsonl"
            self.images_base = str(data_dir)
        else:
            data_dir = Path(data_path)
            jsonl_path = data_dir / "data.jsonl"
            self.images_base = str(data_dir)

        if not jsonl_path.exists():
            raise FileNotFoundError(
                f"Dataset not found at {jsonl_path}. "
                "Remove data_path argument to auto-download."
            )

        dataset = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                item = json.loads(line.strip())
                dataset.append(item)

        print(f"Loaded {len(dataset)} examples from VisuLogic")
        return dataset

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> str:
        image_path = item.get("image_path", "")
        full_path = Path(self.images_base) / image_path
        if full_path.exists():
            with Image.open(full_path) as img:
                return self.encode_image(img)
        raise ValueError(f"Could not find image at {full_path}")

    def build_messages(self, item: dict) -> List[dict]:
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        prompt = f"""{question}

Answer with only the letter (A, B, C, or D)."""

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        response = response.strip().upper()

        for char in reversed(response):
            if char in "ABCD":
                return char

        return ""

    def score(self, prediction: str, answer: str) -> bool:
        if not prediction:
            return False
        return prediction.upper() == answer.upper()

    async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            gen_params = self.get_generation_params()
            completion = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=gen_params["temperature"],
                max_tokens=gen_params["max_tokens"],
            )

            if not completion.choices:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""
            if hasattr(message, "reasoning") and message.reasoning and not response:
                response = message.reasoning
            if not response and hasattr(message, "model_extra"):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning

            if not response:
                return {"accuracy": 0.0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answer = data_item.get("label", "")
            correct = self.score(extracted, answer)

            sample = {
                "question": data_item.get("question", ""),
                "answer": answer,
                "prediction": extracted,
                "correct": correct,
                "tag": data_item.get("tag", ""),
            }

            return {"accuracy": 1.0 if correct else 0.0}, sample

        except Exception as e:
            return {"accuracy": 0.0}, {"error": str(e)}


if __name__ == "__main__":
    asyncio.run(
        eval_runner(
            VisuLogic,
            temperature=0.0,
            max_tokens=256,
        )
    )
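setup_data above assumes a fixed on-disk layout: data.jsonl next to an images/ directory extracted from images.zip, either under the default cache dir or under a user-supplied data_path. The sketch below pre-populates that layout by hand, for example to inspect a record before running the environment; it simply mirrors what _download_data already does, and the "question"/"label" field names are taken from the environment code above rather than verified against the dataset card.

# Minimal sketch: fetch the VisuLogic files into the same cache layout and peek at one record.
import json
import zipfile
from pathlib import Path

from huggingface_hub import hf_hub_download

cache_dir = Path.home() / ".cache" / "visulogic_hf"
cache_dir.mkdir(parents=True, exist_ok=True)
hf_hub_download("VisuLogic/VisuLogic", "data.jsonl", repo_type="dataset", local_dir=cache_dir)
zip_path = hf_hub_download("VisuLogic/VisuLogic", "images.zip", repo_type="dataset", local_dir=cache_dir)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(cache_dir)

with open(cache_dir / "data.jsonl", encoding="utf-8") as f:
    first = json.loads(next(f))
print(first.get("question", "")[:80], first.get("label"))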
381 environments/eval_environments/wemath_environment.py Normal file
@@ -0,0 +1,381 @@
"""We-Math evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
import string
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from environments.eval_environments.eval_base import EvalBase, eval_runner
|
||||
|
||||
|
||||
class WeMath(EvalBase):
|
||||
"""
|
||||
We-Math evaluation environment.
|
||||
|
||||
A benchmark for visual mathematical reasoning with multi-step problems
|
||||
and 4-dimensional evaluation metrics (IK, IG, CM, RM).
|
||||
"""
|
||||
|
||||
def setup_data(self) -> list:
|
||||
split = getattr(self, "split", "testmini")
|
||||
dataset = load_dataset("We-Math/We-Math", split=split)
|
||||
print(f"Loaded {len(dataset)} examples from We-Math ({split})")
|
||||
return list(dataset)
|
||||
|
||||
def encode_image(self, pil_image: Image.Image) -> str:
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def get_image_base64(self, item: dict) -> str:
|
||||
img = item.get("image_path") or item.get("image")
|
||||
if img is not None:
|
||||
if isinstance(img, Image.Image):
|
||||
return self.encode_image(img)
|
||||
elif isinstance(img, bytes):
|
||||
return base64.b64encode(img).decode("utf-8")
|
||||
raise ValueError(
|
||||
f"Could not find image for item {item.get('ID', item.get('problem_id', 'unknown'))}"
|
||||
)
|
||||
|
||||
def build_messages(self, item: dict) -> List[dict]:
|
||||
"""Build prompt with question, options, and optional hint (MCQ format)."""
|
||||
image_base64 = self.get_image_base64(item)
|
||||
question = item.get("question", "")
|
||||
|
||||
# Build options from A-H if present
|
||||
options = {}
|
||||
for letter in string.ascii_uppercase[:8]: # A-H
|
||||
if (
|
||||
letter in item
|
||||
and item[letter] is not None
|
||||
and not pd.isna(item.get(letter, float("nan")))
|
||||
):
|
||||
options[letter] = item[letter]
|
||||
|
||||
# Build prompt
|
||||
prompt_parts = []
|
||||
|
||||
# Add hint if present
|
||||
hint = item.get("hint", "")
|
||||
if hint and not pd.isna(hint):
|
||||
prompt_parts.append(f"Hint: {hint}")
|
||||
|
||||
prompt_parts.append(f"Question: {question}")
|
||||
|
||||
# Add options if present
|
||||
if options:
|
||||
options_text = "Options:\n"
|
||||
for letter, value in options.items():
|
||||
options_text += f"{letter}. {value}\n"
|
||||
prompt_parts.append(options_text)
|
||||
|
||||
# Add COT requirement if dataset is WeMath_COT
|
||||
use_cot = getattr(self, "use_cot", False)
|
||||
requirement = item.get("requirement", "")
|
||||
if use_cot and requirement and not pd.isna(requirement):
|
||||
prompt_parts.append(requirement)
|
||||
else:
|
||||
prompt_parts.append(
|
||||
"Answer with the option's letter from the given choices directly."
|
||||
)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def extract_answer(self, response: str) -> str:
|
||||
"""
|
||||
Extract MCQ answer letter from response.
|
||||
|
||||
Following VLMEvalKit logic: look for letter after "Answer" keyword,
|
||||
or extract first valid letter (A-H).
|
||||
"""
|
||||
response = str(response).strip()
|
||||
|
||||
# Try to find answer after "Answer" keyword
|
||||
answer_match = re.search(r"Answer[:\s]*([A-Ha-h])", response, re.IGNORECASE)
|
||||
if answer_match:
|
||||
return answer_match.group(1).upper()
|
||||
|
||||
# Clean response and look for first valid letter
|
||||
cleaned = re.sub(r"[>><<:.\s]", "", response).strip()
|
||||
if cleaned and cleaned[0].upper() in "ABCDEFGH":
|
||||
return cleaned[0].upper()
|
||||
|
||||
# Fallback: find any letter A-H in the response
|
||||
for char in response.upper():
|
||||
if char in "ABCDEFGH":
|
||||
return char
|
||||
|
||||
return ""
|
||||
|
||||
def score(self, prediction: str, answer: str) -> bool:
|
||||
"""Check if prediction matches answer (case-insensitive)."""
|
||||
if not prediction:
|
||||
return False
|
||||
return prediction.upper() == answer.upper()
|
||||
|
||||
async def run_item(self, client: AsyncOpenAI, data_item: dict) -> Tuple[dict, dict]:
|
||||
try:
|
||||
messages = self.build_messages(data_item)
|
||||
|
||||
gen_params = self.get_generation_params()
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=gen_params["temperature"],
|
||||
max_tokens=gen_params["max_tokens"],
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}
|
||||
|
||||
message = completion.choices[0].message
|
||||
response = message.content or ""
|
||||
if hasattr(message, "reasoning") and message.reasoning and not response:
|
||||
response = message.reasoning
|
||||
if not response and hasattr(message, "model_extra"):
|
||||
reasoning = message.model_extra.get("reasoning", "")
|
||||
if reasoning:
|
||||
response = reasoning
|
||||
|
||||
if not response:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}
|
||||
|
||||
extracted = self.extract_answer(response)
|
||||
answer = data_item.get("answer", "")
|
||||
correct = self.score(extracted, answer)
|
||||
|
||||
# Get problem metadata for 4-dimensional analysis
|
||||
problem_id = data_item.get("ID", data_item.get("problem_id", ""))
|
||||
key = data_item.get("key", "") # e.g., "2steps_1", "2steps_multi", etc.
|
||||
|
||||
sample = {
|
||||
"ID": problem_id,
|
||||
"key": key,
|
||||
"question": data_item.get("question", ""),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500], # Truncate for logging
|
||||
"hit": 1 if correct else 0,
|
||||
"joker": correct, # VLMEvalKit naming convention
|
||||
}
|
||||
|
||||
return {
|
||||
"accuracy": 1.0 if correct else 0.0,
|
||||
"hit": 1 if correct else 0,
|
||||
}, sample
|
||||
|
||||
except Exception as e:
|
||||
return {"accuracy": 0.0, "hit": 0}, {"error": str(e)}
|
||||
|
||||
|
||||
def compute_4d_metrics(samples: List[dict]) -> Dict:
|
||||
"""
|
||||
Compute We-Math 4-dimensional metrics from evaluation samples.
|
||||
|
||||
This implements the evaluation logic from VLMEvalKit's wemath.py.
|
||||
|
||||
Returns metrics for:
|
||||
- IK (Insufficient Knowledge): Steps wrong AND multi wrong
|
||||
- IG (Inadequate Generalization): Steps right BUT multi wrong
|
||||
- CM (Complete Mastery): Steps right AND multi right
|
||||
- RM (Rote Memorization): Steps wrong BUT multi right
|
||||
"""
|
||||
# Convert samples to DataFrame
|
||||
df = pd.DataFrame(samples)
|
||||
|
||||
if "key" not in df.columns or df["key"].isna().all():
|
||||
# Dataset doesn't have step structure, return basic accuracy
|
||||
return {
|
||||
"overall_accuracy": df["hit"].mean() if "hit" in df.columns else 0.0,
|
||||
"note": "Dataset lacks step structure for 4D metrics",
|
||||
}
|
||||
|
||||
# Separate by step type
|
||||
data_2steps = df[df["key"].str.contains("2steps", na=False)]
|
||||
data_3steps = df[df["key"].str.contains("3steps", na=False)]
|
||||
|
||||
metrics = {}
|
||||
|
||||
# Process 2-step problems
|
||||
if len(data_2steps) > 0:
|
||||
merged_2steps = _process_steps_data(data_2steps, 2)
|
||||
if merged_2steps is not None:
|
||||
metrics["2step"] = _calculate_step_metrics(merged_2steps, 2)
|
||||
|
||||
# Process 3-step problems
|
||||
if len(data_3steps) > 0:
|
||||
merged_3steps = _process_steps_data(data_3steps, 3)
|
||||
if merged_3steps is not None:
|
||||
metrics["3step"] = _calculate_step_metrics(merged_3steps, 3)
|
||||
|
||||
# Compute overall 4D metrics
|
||||
if "2step" in metrics or "3step" in metrics:
|
||||
total_counts = _compute_total_counts(metrics)
|
||||
total_count = 525 # Standard We-Math total
|
||||
|
||||
# Compute rates and final scores
|
||||
final_metrics = _compute_final_scores(total_counts, total_count)
|
||||
metrics["overall"] = final_metrics
|
||||
|
||||
# Basic accuracy
|
||||
metrics["step_accuracy"] = df["hit"].mean() if "hit" in df.columns else 0.0
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _process_steps_data(df: pd.DataFrame, steps: int) -> pd.DataFrame:
|
||||
"""Process step data and merge by problem ID."""
|
||||
try:
|
||||
steps_data = {}
|
||||
for i in range(1, steps + 1):
|
||||
key = f"{steps}steps_{i}"
|
||||
step_df = df[df["key"] == key].copy()
|
||||
if len(step_df) == 0:
|
||||
return None
|
||||
step_df.columns = [f"{col}_{i}" for col in step_df.columns]
|
||||
steps_data[i] = step_df
|
||||
|
||||
# Get multi-step data
|
||||
multi_key = f"{steps}steps_multi"
|
||||
multi_df = df[df["key"] == multi_key].copy()
|
||||
if len(multi_df) == 0:
|
||||
return None
|
||||
multi_df.columns = [f"{col}_multi" for col in multi_df.columns]
|
||||
|
||||
# Merge all steps
|
||||
merged = steps_data[1]
|
||||
for i in range(2, steps + 1):
|
||||
merged = pd.merge(
|
||||
merged,
|
||||
steps_data[i],
|
||||
left_on="ID_1",
|
||||
right_on=f"ID_{i}",
|
||||
how="left",
|
||||
)
|
||||
merged = pd.merge(
|
||||
merged, multi_df, left_on="ID_1", right_on="ID_multi", how="left"
|
||||
)
|
||||
|
||||
return merged
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _calculate_step_metrics(merged: pd.DataFrame, steps: int) -> Dict:
|
||||
"""Calculate metrics for a step type (2-step or 3-step)."""
|
||||
try:
|
||||
# Get joker columns
|
||||
joker_cols = [f"joker_{i}" for i in range(1, steps + 1)]
|
||||
joker_multi = "joker_multi"
|
||||
|
||||
# Check if columns exist
|
||||
for col in joker_cols + [joker_multi]:
|
||||
if col not in merged.columns:
|
||||
return {}
|
||||
|
||||
# Calculate conditions
|
||||
all_steps_correct = merged[joker_cols].all(axis=1)
|
||||
any_step_correct = merged[joker_cols].any(axis=1)
|
||||
all_steps_wrong = ~merged[joker_cols].any(axis=1)
|
||||
any_step_wrong = ~merged[joker_cols].all(axis=1)
|
||||
multi_correct = merged[joker_multi] == True # noqa: E712
|
||||
|
||||
return {
|
||||
# Strict: ALL steps must be correct
|
||||
"CM_strict": int((all_steps_correct & multi_correct).sum()),
|
||||
"IG": int((all_steps_correct & ~multi_correct).sum()),
|
||||
"RM_strict": int((any_step_wrong & multi_correct).sum()),
|
||||
"IK": int((any_step_wrong & ~multi_correct).sum()),
|
||||
# Loose: ANY step correct
|
||||
"CM_loose": int((any_step_correct & multi_correct).sum()),
|
||||
"RM_loose": int((all_steps_wrong & multi_correct).sum()),
|
||||
"total": len(merged),
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _compute_total_counts(metrics: Dict) -> Dict:
|
||||
"""Aggregate counts across step types."""
|
||||
totals = defaultdict(int)
|
||||
|
||||
for step_type in ["2step", "3step"]:
|
||||
if step_type in metrics:
|
||||
for key in ["CM_strict", "CM_loose", "IG", "RM_strict", "RM_loose", "IK"]:
|
||||
if key in metrics[step_type]:
|
||||
totals[key] += metrics[step_type][key]
|
||||
|
||||
return dict(totals)
|
||||
|
||||
|
||||
def _compute_final_scores(total_counts: Dict, total_count: int = 525) -> Dict:
|
||||
"""Compute final 4D scores and rates."""
|
||||
results = {}
|
||||
|
||||
# Calculate rates
|
||||
for key in ["IK", "IG", "CM_strict", "CM_loose", "RM_strict", "RM_loose"]:
|
||||
count = total_counts.get(key, 0)
|
||||
results[f"{key}_count"] = count
|
||||
results[f"{key}_rate"] = count / total_count if total_count > 0 else 0.0
|
||||
|
||||
# Calculate RM rates (relative to CM + RM)
|
||||
cm_rm_strict = total_counts.get("CM_strict", 0) + total_counts.get("RM_strict", 0)
|
||||
cm_rm_loose = total_counts.get("CM_loose", 0) + total_counts.get("RM_loose", 0)
|
||||
|
||||
results["RM_strict_relative"] = (
|
||||
total_counts.get("RM_strict", 0) / cm_rm_strict if cm_rm_strict > 0 else 0.0
|
||||
)
|
||||
results["RM_loose_relative"] = (
|
||||
total_counts.get("RM_loose", 0) / cm_rm_loose if cm_rm_loose > 0 else 0.0
|
||||
)
|
||||
|
||||
# Final scores (VLMEvalKit formula)
|
||||
results["score_strict"] = (
|
||||
total_count
|
||||
- 0.5 * total_counts.get("IG", 0)
|
||||
- total_counts.get("RM_strict", 0)
|
||||
- total_counts.get("IK", 0)
|
||||
) / total_count
|
||||
|
||||
results["score_loose"] = (
|
||||
total_count
|
||||
- 0.5 * total_counts.get("IG", 0)
|
||||
- total_counts.get("RM_loose", 0)
|
||||
- total_counts.get("IK", 0)
|
||||
) / total_count
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(
|
||||
eval_runner(
|
||||
WeMath,
|
||||
split="testmini",
|
||||
use_cot=False,
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
)
|
||||
)
|
||||
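The 4-dimensional bookkeeping above classifies each problem by whether its sub-steps and its multi-step question were answered correctly, and the final score applies the VLMEvalKit formula score_strict = (total - 0.5*IG - RM_strict - IK) / total. Below is a toy sanity check on one synthetic 2-step problem; the private helper imports and the reduced total_count=1 (instead of the standard 525) are assumptions made purely for illustration.

# Toy worked example of the IK/IG/CM/RM classification and the strict score formula.
import pandas as pd

from environments.eval_environments.wemath_environment import (
    _calculate_step_metrics,
    _compute_final_scores,
)

# One 2-step problem: both sub-steps solved, multi-step question missed -> counted as IG.
merged = pd.DataFrame({"joker_1": [True], "joker_2": [True], "joker_multi": [False]})
counts = _calculate_step_metrics(merged, steps=2)
print(counts)  # CM_strict=0, IG=1, RM_strict=0, IK=0, ...

# score_strict = (1 - 0.5*1 - 0 - 0) / 1 = 0.5
print(_compute_final_scores(counts, total_count=1)["score_strict"])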