mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
369 lines
13 KiB
Python
369 lines
13 KiB
Python
"""We-Math evaluation environment."""
|
|
|
|
import asyncio
import base64
import io
import re
import string
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import pandas as pd
from datasets import load_dataset
from PIL import Image

from atroposlib.envs.server_handling.server_manager import ServerManager
from environments.eval_environments.eval import EvalBase, eval_runner
|
|
|
|
|
|
class WeMath(EvalBase):
    """
    We-Math evaluation environment.

    A benchmark for visual mathematical reasoning with multi-step problems
    and 4-dimensional evaluation metrics (IK, IG, CM, RM).

    Every item is presented to the model as an image plus a multiple-choice
    prompt; the reply is reduced to one option letter (A-H) and compared with
    the item's ``answer`` field.
    """

    def setup_data(self) -> list:
        """Load the We-Math split and return its items as a list.

        The split name is read from ``self.split`` and falls back to
        ``"testmini"`` when the attribute is absent.
        """
        split = getattr(self, "split", "testmini")
        dataset = load_dataset("We-Math/We-Math", split=split)
        print(f"Loaded {len(dataset)} examples from We-Math ({split})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        """Serialize a PIL image as PNG and return it base64-encoded."""
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_image_base64(self, item: dict) -> str:
        """Return the item's image as a base64 string.

        The image is looked up under ``"image_path"`` and then ``"image"``.
        PIL images are re-encoded as PNG; ``bytes`` are base64-encoded as-is.

        Raises:
            ValueError: if neither key holds a PIL image or ``bytes``.
        """
        img = item.get("image_path") or item.get("image")
        if isinstance(img, Image.Image):
            return self.encode_image(img)
        if isinstance(img, bytes):
            return base64.b64encode(img).decode("utf-8")
        raise ValueError(
            f"Could not find image for item {item.get('ID', item.get('problem_id', 'unknown'))}"
        )

    def build_messages(self, item: dict) -> List[dict]:
        """Build prompt with question, options, and optional hint (MCQ format).

        Returns a single user message whose content is the image (as a data
        URL) followed by the text prompt.
        """
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        # Collect the options A-H that are present and neither None nor NaN.
        options = {}
        for letter in string.ascii_uppercase[:8]:
            value = item.get(letter)
            if value is not None and not pd.isna(value):
                options[letter] = value

        prompt_parts = []

        # Optional hint goes first.
        hint = item.get("hint", "")
        if hint and not pd.isna(hint):
            prompt_parts.append(f"Hint: {hint}")

        prompt_parts.append(f"Question: {question}")

        if options:
            option_lines = "".join(
                f"{letter}. {value}\n" for letter, value in options.items()
            )
            prompt_parts.append("Options:\n" + option_lines)

        # With use_cot enabled, the item's own "requirement" text replaces the
        # default "answer with the letter" instruction.
        use_cot = getattr(self, "use_cot", False)
        requirement = item.get("requirement", "")
        if use_cot and requirement and not pd.isna(requirement):
            prompt_parts.append(requirement)
        else:
            prompt_parts.append(
                "Answer with the option's letter from the given choices directly."
            )

        prompt = "\n".join(prompt_parts)

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        # NOTE(review): the data URL always declares image/png,
                        # including when the item supplied raw bytes.
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        """
        Extract MCQ answer letter from response.

        Heuristics, in priority order:

        1. A letter A-H following the word "answer" (optionally "answer is",
           with ``:``/``*``/whitespace and opening brackets in between).
        2. A letter A-H that opens the response (after punctuation/brackets).
        3. The first standalone capital letter A-H anywhere in the response.

        In every case the letter must not be followed by another letter, so
        the first character of an ordinary word ("because", "THE") is never
        mistaken for an option. Returns "" when nothing matches.
        """
        text = str(response).strip()

        # 1. Letter introduced by the "answer" keyword.
        keyword = re.search(
            r"answer(?:\s+is)?[:\s*]*[(\[<]*([A-H])(?![A-Za-z])",
            text,
            re.IGNORECASE,
        )
        if keyword:
            return keyword.group(1).upper()

        # 2. Response that starts with a bare option letter, e.g. "B", "(c)".
        leading = re.match(
            r"[\s>(<\[:.*]*([A-H])(?![A-Za-z])", text, re.IGNORECASE
        )
        if leading:
            return leading.group(1).upper()

        # 3. Capital letters only: a lone lowercase "a" is usually an article.
        standalone = re.search(r"(?<![A-Za-z])([A-H])(?![A-Za-z])", text)
        if standalone:
            return standalone.group(1)

        return ""

    def score(self, prediction: str, answer: str) -> bool:
        """Check if prediction matches answer (case-insensitive).

        Surrounding whitespace is ignored; an empty prediction or a missing
        (None) answer never scores as correct.
        """
        if not prediction or answer is None:
            return False
        return str(prediction).strip().upper() == str(answer).strip().upper()

    async def run_item(self, server: ServerManager, data_item: dict) -> Tuple[dict, dict]:
        """Query the model for one item and score its answer.

        Returns a ``(metrics, sample)`` pair. ``metrics`` holds ``accuracy``
        and ``hit``; ``sample`` holds the per-item record (including ``key``
        and ``joker`` for the 4-dimensional analysis). On an empty reply or
        any exception, metrics are zero and the sample is ``{"error": ...}``.
        """
        try:
            messages = self.build_messages(data_item)

            completion = await self.chat_completion(server, messages)

            if not completion.choices:
                return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}

            message = completion.choices[0].message
            response = message.content or ""
            # Fall back to the reasoning text when the content is empty.
            if not response:
                response = getattr(message, "reasoning", None) or ""
            if not response:
                # model_extra may exist but be None; treat that as "no extras".
                extra = getattr(message, "model_extra", None) or {}
                response = extra.get("reasoning", "") or ""

            if not response:
                return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answer = data_item.get("answer", "")
            correct = self.score(extracted, answer)

            # Get problem metadata for 4-dimensional analysis
            problem_id = data_item.get("ID", data_item.get("problem_id", ""))
            key = data_item.get("key", "")  # e.g., "2steps_1", "2steps_multi", etc.

            sample = {
                "ID": problem_id,
                "key": key,
                "question": data_item.get("question", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],  # Truncate for logging
                "hit": 1 if correct else 0,
                "joker": correct,  # VLMEvalKit naming convention
            }

            return {
                "accuracy": 1.0 if correct else 0.0,
                "hit": 1 if correct else 0,
            }, sample

        except Exception as e:
            # Per-item boundary: a failure on one item is reported, not raised.
            return {"accuracy": 0.0, "hit": 0}, {"error": str(e)}
|
|
|
|
|
|
def compute_4d_metrics(samples: List[dict], total_count: int = 525) -> Dict:
    """
    Compute We-Math 4-dimensional metrics from evaluation samples.

    This implements the evaluation logic from VLMEvalKit's wemath.py.

    Returns metrics for:
    - IK (Insufficient Knowledge): Steps wrong AND multi wrong
    - IG (Inadequate Generalization): Steps right BUT multi wrong
    - CM (Complete Mastery): Steps right AND multi right
    - RM (Rote Memorization): Steps wrong BUT multi right

    Args:
        samples: Per-item records; the ones used here carry ``ID``, ``key``,
            ``hit`` and ``joker``. Records without a ``hit`` value (for
            example error-only records) are counted as misses.
        total_count: Denominator for the overall rates and final scores.
            Defaults to 525 (the standard We-Math total).

    Returns:
        A dict with ``step_accuracy`` plus, when step data is available,
        ``2step`` / ``3step`` counts and an ``overall`` score block. If the
        samples carry no ``key`` information, only ``overall_accuracy`` and
        an explanatory ``note`` are returned.
    """
    df = pd.DataFrame(samples)

    # mean() skips NaN, which would silently drop failed items from the
    # denominator; coerce missing hits to 0 so they count as wrong.
    if "hit" in df.columns:
        accuracy = float(pd.to_numeric(df["hit"], errors="coerce").fillna(0).mean())
    else:
        accuracy = 0.0

    if "key" not in df.columns or df["key"].isna().all():
        # Dataset doesn't have step structure, return basic accuracy
        return {
            "overall_accuracy": accuracy,
            "note": "Dataset lacks step structure for 4D metrics",
        }

    metrics = {}

    # Process 2-step and 3-step problems separately.
    for steps in (2, 3):
        subset = df[df["key"].str.contains(f"{steps}steps", na=False)]
        if len(subset) == 0:
            continue
        merged = _process_steps_data(subset, steps)
        if merged is not None:
            metrics[f"{steps}step"] = _calculate_step_metrics(merged, steps)

    # Overall 4D metrics only exist when at least one step type was merged.
    if metrics:
        total_counts = _compute_total_counts(metrics)
        metrics["overall"] = _compute_final_scores(total_counts, total_count)

    # Basic accuracy
    metrics["step_accuracy"] = accuracy

    return metrics
|
|
|
|
|
|
def _process_steps_data(df: pd.DataFrame, steps: int) -> pd.DataFrame:
|
|
"""Process step data and merge by problem ID."""
|
|
try:
|
|
steps_data = {}
|
|
for i in range(1, steps + 1):
|
|
key = f"{steps}steps_{i}"
|
|
step_df = df[df["key"] == key].copy()
|
|
if len(step_df) == 0:
|
|
return None
|
|
step_df.columns = [f"{col}_{i}" for col in step_df.columns]
|
|
steps_data[i] = step_df
|
|
|
|
# Get multi-step data
|
|
multi_key = f"{steps}steps_multi"
|
|
multi_df = df[df["key"] == multi_key].copy()
|
|
if len(multi_df) == 0:
|
|
return None
|
|
multi_df.columns = [f"{col}_multi" for col in multi_df.columns]
|
|
|
|
# Merge all steps
|
|
merged = steps_data[1]
|
|
for i in range(2, steps + 1):
|
|
merged = pd.merge(
|
|
merged,
|
|
steps_data[i],
|
|
left_on="ID_1",
|
|
right_on=f"ID_{i}",
|
|
how="left",
|
|
)
|
|
merged = pd.merge(
|
|
merged, multi_df, left_on="ID_1", right_on="ID_multi", how="left"
|
|
)
|
|
|
|
return merged
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _calculate_step_metrics(merged: pd.DataFrame, steps: int) -> Dict:
|
|
"""Calculate metrics for a step type (2-step or 3-step)."""
|
|
try:
|
|
# Get joker columns
|
|
joker_cols = [f"joker_{i}" for i in range(1, steps + 1)]
|
|
joker_multi = "joker_multi"
|
|
|
|
# Check if columns exist
|
|
for col in joker_cols + [joker_multi]:
|
|
if col not in merged.columns:
|
|
return {}
|
|
|
|
# Calculate conditions
|
|
all_steps_correct = merged[joker_cols].all(axis=1)
|
|
any_step_correct = merged[joker_cols].any(axis=1)
|
|
all_steps_wrong = ~merged[joker_cols].any(axis=1)
|
|
any_step_wrong = ~merged[joker_cols].all(axis=1)
|
|
multi_correct = merged[joker_multi] == True # noqa: E712
|
|
|
|
return {
|
|
# Strict: ALL steps must be correct
|
|
"CM_strict": int((all_steps_correct & multi_correct).sum()),
|
|
"IG": int((all_steps_correct & ~multi_correct).sum()),
|
|
"RM_strict": int((any_step_wrong & multi_correct).sum()),
|
|
"IK": int((any_step_wrong & ~multi_correct).sum()),
|
|
# Loose: ANY step correct
|
|
"CM_loose": int((any_step_correct & multi_correct).sum()),
|
|
"RM_loose": int((all_steps_wrong & multi_correct).sum()),
|
|
"total": len(merged),
|
|
}
|
|
except Exception:
|
|
return {}
|
|
|
|
|
|
def _compute_total_counts(metrics: Dict) -> Dict:
|
|
"""Aggregate counts across step types."""
|
|
totals = defaultdict(int)
|
|
|
|
for step_type in ["2step", "3step"]:
|
|
if step_type in metrics:
|
|
for key in ["CM_strict", "CM_loose", "IG", "RM_strict", "RM_loose", "IK"]:
|
|
if key in metrics[step_type]:
|
|
totals[key] += metrics[step_type][key]
|
|
|
|
return dict(totals)
|
|
|
|
|
|
def _compute_final_scores(total_counts: Dict, total_count: int = 525) -> Dict:
|
|
"""Compute final 4D scores and rates."""
|
|
results = {}
|
|
|
|
# Calculate rates
|
|
for key in ["IK", "IG", "CM_strict", "CM_loose", "RM_strict", "RM_loose"]:
|
|
count = total_counts.get(key, 0)
|
|
results[f"{key}_count"] = count
|
|
results[f"{key}_rate"] = count / total_count if total_count > 0 else 0.0
|
|
|
|
# Calculate RM rates (relative to CM + RM)
|
|
cm_rm_strict = total_counts.get("CM_strict", 0) + total_counts.get("RM_strict", 0)
|
|
cm_rm_loose = total_counts.get("CM_loose", 0) + total_counts.get("RM_loose", 0)
|
|
|
|
results["RM_strict_relative"] = (
|
|
total_counts.get("RM_strict", 0) / cm_rm_strict if cm_rm_strict > 0 else 0.0
|
|
)
|
|
results["RM_loose_relative"] = (
|
|
total_counts.get("RM_loose", 0) / cm_rm_loose if cm_rm_loose > 0 else 0.0
|
|
)
|
|
|
|
# Final scores (VLMEvalKit formula)
|
|
results["score_strict"] = (
|
|
total_count
|
|
- 0.5 * total_counts.get("IG", 0)
|
|
- total_counts.get("RM_strict", 0)
|
|
- total_counts.get("IK", 0)
|
|
) / total_count
|
|
|
|
results["score_loose"] = (
|
|
total_count
|
|
- 0.5 * total_counts.get("IG", 0)
|
|
- total_counts.get("RM_loose", 0)
|
|
- total_counts.get("IK", 0)
|
|
) / total_count
|
|
|
|
return results
|
|
|
|
|
|
# Script entry point: evaluate the testmini split without chain-of-thought
# prompting, using deterministic decoding and a 512-token reply budget.
if __name__ == "__main__":
    asyncio.run(
        eval_runner(WeMath(split="testmini", use_cot=False, temperature=0.0, max_tokens=512))
    )
|