atropos/environments/eval_environments/vision_evals/charxiv_environment.py

"""CharXiv evaluation environment."""
import asyncio
import base64
import io
import json
import os
import re
from typing import Dict, List, Optional, Tuple
import openai
from datasets import load_dataset
from environments.eval_environments.eval import EvalBase, eval_runner
from PIL import Image
from atroposlib.envs.server_handling.server_manager import ServerManager
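# Maps each descriptive question template id (1-19) to its CharXiv category,
# used for per-category metric aggregation.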
DESCRIPTIVE_CATEGORIES = {
1: "Information Extraction",
2: "Information Extraction",
3: "Information Extraction",
4: "Information Extraction",
5: "Information Extraction",
6: "Information Extraction",
7: "Information Extraction",
8: "Enumeration",
9: "Enumeration",
10: "Counting",
11: "Pattern Recognition",
12: "Counting",
13: "Enumeration",
14: "Enumeration",
15: "Enumeration",
16: "Pattern Recognition",
17: "Compositionality",
18: "Pattern Recognition",
19: "Counting",
}
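# Maps each reasoning instruction-category id to its CharXiv category name.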
REASONING_CATEGORIES = {
1: "Text-in-Chart",
2: "Text-in-General",
3: "Number-in-Chart",
4: "Number-in-General",
}
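# Canonical question text for each descriptive template id; used when a dataset
# item stores only the template id rather than the full question string.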
DESCRIPTIVE_QUESTIONS = {
1: "What is the title of the chart?",
2: "What is the label of the x-axis?",
3: "What is the label of the y-axis?",
4: "What is the leftmost labeled tick on the x-axis?",
5: "What is the rightmost labeled tick on the x-axis?",
6: "What is the spatially lowest labeled tick on the y-axis?",
7: "What is the spatially highest labeled tick on the y-axis?",
8: "What are all the labels in the legend?",
9: "List all the categories in the x-axis.",
10: "How many distinct bars are there?",
11: "Does the chart contain a grid?",
12: "How many lines are there in the chart?",
13: "Is there a legend in the chart?",
14: "What are the names of the curves in the chart?",
15: "Does the chart contain horizontal bars?",
16: "Do the bars have error bars?",
17: "Describe the general trend of the chart.",
18: "Is there any point emphasized/highlighted in the chart?",
19: "How many sections does the pie chart have?",
}
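# Fallback judging prompt, used when a dataset item does not provide its own
# grading_query string (see _judge_with_gpt).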
GRADING_QUERY_TEMPLATE = """You are evaluating a model's answer to a chart understanding question.
Question: {question}
Ground Truth Answer: {answer}
Model's Answer: {prediction}
Please evaluate whether the model's answer is correct or partially correct.
Consider semantic equivalence - different phrasings that mean the same thing should be considered correct.
For numerical answers, exact matches or very close values should be considered correct.
For yes/no questions, the meaning should match the ground truth.
For enumeration questions (listing items), all items should be present regardless of order.
Respond with a JSON object containing:
- "extract_answer": The key answer extracted from the model's response
- "score": A float from 0.0 to 1.0 indicating correctness (0.0 = wrong, 0.5 = partial, 1.0 = correct)
Example response: {{"extract_answer": "60", "score": 1.0}}"""
class CharXiv(EvalBase):
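    """Evaluation environment for the CharXiv chart-understanding benchmark.

    Supports the benchmark's descriptive and reasoning question modes, with
    optional LLM-as-judge scoring and a heuristic fallback scorer.
    """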
MODES = ["descriptive", "reasoning"]
def setup_data(self) -> list:
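        """Load the requested CharXiv split and flatten it into per-question records.

        In descriptive mode, each figure's descriptive_q1..q4 / descriptive_a1..a4
        fields yield one record apiece; in reasoning mode, each figure yields a
        single record from its reasoning_q / reasoning_a fields.
        """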
mode = getattr(self, "mode", "descriptive")
split = getattr(self, "split", "validation")
dataset = load_dataset("princeton-nlp/CharXiv", "default", split=split)
data = []
for item in dataset:
if mode == "descriptive":
for i in range(1, 5):
q_key = f"descriptive_q{i}"
a_key = f"descriptive_a{i}"
if a_key in item and item[a_key]:
template_id = item.get(q_key, i)
if (
isinstance(template_id, int)
and template_id in DESCRIPTIVE_QUESTIONS
):
question = DESCRIPTIVE_QUESTIONS[template_id]
else:
question = (
str(template_id)
if template_id
else f"Descriptive question {i}"
)
data.append(
{
"image": item["image"],
"question": question,
"answer": item[a_key],
"qid": (
template_id if isinstance(template_id, int) else i
),
"category": item.get("category", ""),
"grading_query": item.get("grading_query", ""),
}
)
elif mode == "reasoning":
if "reasoning_q" in item and item.get("reasoning_a"):
data.append(
{
"image": item["image"],
"question": item["reasoning_q"],
"answer": item["reasoning_a"],
"inst_category": item.get("category", 1),
"grading_query": item.get("grading_query", ""),
}
)
else:
raise ValueError(
f"Invalid mode: {mode}. Must be 'descriptive' or 'reasoning'."
)
print(f"Loaded {len(data)} examples from CharXiv ({mode}, {split})")
return data
def encode_image(self, pil_image: Image.Image) -> str:
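        """Encode a PIL image as a base64 PNG string."""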
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def get_image_base64(self, item: dict) -> str:
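        """Return the item's image as base64 PNG, raising if no PIL image is present."""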
if "image" in item and item["image"] is not None:
if isinstance(item["image"], Image.Image):
return self.encode_image(item["image"])
raise ValueError("Could not find image for item")
def build_messages(self, item: dict) -> List[dict]:
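        """Build a single-turn multimodal message pairing the chart image with the question."""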
image_base64 = self.get_image_base64(item)
question = item.get("question", "")
return [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
},
{"type": "text", "text": question},
],
}
]
async def _judge_with_gpt(
self, question: str, answer: str, prediction: str, grading_query: str = ""
) -> Tuple[Optional[str], float]:
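        """Score a prediction with an LLM judge.

        Uses the item's grading_query when available (with {PREDICTION} substituted),
        otherwise the generic GRADING_QUERY_TEMPLATE. Returns (extracted_answer, score);
        (None, 0.0) signals that judging was unavailable or failed, so callers can
        fall back to heuristic scoring.
        """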
judge_model = getattr(self, "judge_model", "gpt-4o-mini")
judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1")
judge_api_key = os.environ.get(
getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), ""
)
if not judge_api_key:
return None, 0.0
if grading_query:
prompt = grading_query.replace("{PREDICTION}", prediction)
else:
prompt = GRADING_QUERY_TEMPLATE.format(
question=question, answer=answer, prediction=prediction
)
try:
judge_client = openai.AsyncOpenAI(
api_key=judge_api_key,
base_url=judge_base_url,
)
completion = await judge_client.chat.completions.create(
model=judge_model,
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
max_tokens=256,
)
response = completion.choices[0].message.content.strip()
try:
result = json.loads(response)
if isinstance(result, dict):
extract_answer = result.get("extract_answer", "")
score = float(result.get("score", 0.0))
return extract_answer, score
except json.JSONDecodeError:
json_match = re.search(r"\{[^}]+\}", response)
if json_match:
try:
result = json.loads(json_match.group())
extract_answer = result.get("extract_answer", "")
score = float(result.get("score", 0.0))
return extract_answer, score
except (json.JSONDecodeError, ValueError):
pass
return None, 0.0
except Exception as e:
print(f"GPT judge error: {e}")
return None, 0.0
def _fallback_score(
self, prediction: str, answer: str, mode: str
) -> Tuple[str, float]:
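        """Heuristic scoring used when the LLM judge is disabled or unavailable.

        Reasoning mode: exact substring match or numeric match within 0.01.
        Descriptive mode: fraction of ground-truth words present in the prediction.
        """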
prediction = prediction.strip().lower()
answer = answer.strip().lower()
if not prediction:
return "", 0.0
if mode == "reasoning":
if answer in prediction:
return prediction, 1.0
try:
pred_nums = re.findall(r"-?\d+\.?\d*", prediction)
ans_nums = re.findall(r"-?\d+\.?\d*", answer)
if pred_nums and ans_nums:
for p in pred_nums:
for a in ans_nums:
if abs(float(p) - float(a)) < 0.01:
return prediction, 1.0
except ValueError:
pass
return prediction, 0.0
else:
pred_words = set(prediction.split())
ans_words = set(answer.split())
if not ans_words:
return prediction, 0.0
overlap = len(pred_words & ans_words) / len(ans_words)
return prediction, min(overlap, 1.0)
def get_category(self, item: dict, mode: str) -> str:
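        """Map an item to its CharXiv category name for metric aggregation."""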
if mode == "descriptive":
qid = item.get("qid", 1)
return DESCRIPTIVE_CATEGORIES.get(qid, "Information Extraction")
else:
inst_category = item.get("inst_category", 1)
return REASONING_CATEGORIES.get(inst_category, "Text-in-Chart")
async def run_item(
self, server: ServerManager, data_item: dict
) -> Tuple[dict, dict]:
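        """Query the model on one item, score the response, and return (metrics, sample)."""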
try:
messages = self.build_messages(data_item)
mode = getattr(self, "mode", "descriptive")
completion = await self.chat_completion(server, messages)
if not completion.choices:
return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"}
message = completion.choices[0].message
response = message.content or ""
if hasattr(message, "reasoning") and message.reasoning and not response:
response = message.reasoning
            # model_extra can be None on pydantic models, so guard before calling .get()
            if not response and getattr(message, "model_extra", None):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning
if not response:
return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"}
use_gpt_judge = getattr(self, "use_gpt_judge", True)
grading_query = data_item.get("grading_query", "")
answer = data_item.get("answer", "")
question = data_item.get("question", "")
if use_gpt_judge:
extracted, score = await self._judge_with_gpt(
question, answer, response, grading_query
)
evaluation_method = "gpt_judge"
else:
extracted, score = self._fallback_score(response, answer, mode)
evaluation_method = "fallback"
if extracted is None:
extracted, score = self._fallback_score(response, answer, mode)
evaluation_method = "fallback"
category = self.get_category(data_item, mode)
sample = {
"question": data_item.get("question", ""),
"answer": answer,
"prediction": response[:500],
"extract_answer": extracted,
"score": score,
"category": category,
"mode": mode,
"qid": data_item.get("qid", data_item.get("inst_category", "")),
"evaluation_method": evaluation_method,
}
return {"accuracy": score, "score": score}, sample
except Exception as e:
return {"accuracy": 0.0, "score": 0.0}, {"error": str(e)}
def compute_category_metrics(samples: List[dict]) -> Dict:
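    """Aggregate per-category and overall average scores from evaluation samples."""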
from collections import defaultdict
scores_by_category = defaultdict(list)
for sample in samples:
if "error" in sample:
continue
category = sample.get("category", "Unknown")
score = sample.get("score", 0.0)
scores_by_category[category].append(score)
metrics = {}
total_score = 0.0
total_count = 0
for category, scores in scores_by_category.items():
if scores:
avg_score = sum(scores) / len(scores)
metrics[category] = {
"count": len(scores),
"average_score": avg_score,
}
total_score += sum(scores)
total_count += len(scores)
if total_count > 0:
metrics["Overall"] = {
"count": total_count,
"average_score": total_score / total_count,
}
return metrics
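# Example invocation; keyword arguments are read back via getattr on the
# environment (mode, split, use_gpt_judge, judge_model, ...).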
if __name__ == "__main__":
asyncio.run(
eval_runner(
CharXiv(
mode="descriptive", # or "reasoning"
split="validation",
use_gpt_judge=True,
judge_model="gpt-4o-mini",
temperature=0.0,
max_tokens=1024,
)
)
)