mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00

commit 0a16fafadb (parent d84e3c70b7): evals moved + readme
29 changed files with 390 additions and 696 deletions
@@ -0,0 +1,369 @@
"""We-Math evaluation environment."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
import string
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
from environments.eval_environments.eval import EvalBase, eval_runner
|
||||
|
||||
|
||||


class WeMath(EvalBase):
    """
    We-Math evaluation environment.

    A benchmark for visual mathematical reasoning with multi-step problems
    and 4-dimensional evaluation metrics (IK, IG, CM, RM).
    """
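
    # Note: run kwargs (split, use_cot, temperature, max_tokens) are passed to
    # the constructor and read back via getattr below, so EvalBase presumably
    # stores them as instance attributes.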

    def setup_data(self) -> list:
        split = getattr(self, "split", "testmini")
        dataset = load_dataset("We-Math/We-Math", split=split)
        print(f"Loaded {len(dataset)} examples from We-Math ({split})")
        return list(dataset)

    def encode_image(self, pil_image: Image.Image) -> str:
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
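
    # The raw item may hold the image as a decoded PIL object or as
    # already-encoded bytes; both are normalized to a base64 string. Note the
    # data URL in build_messages always claims image/png, which assumes byte
    # payloads are PNG-encoded.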
    def get_image_base64(self, item: dict) -> str:
        img = item.get("image_path") or item.get("image")
        if img is not None:
            if isinstance(img, Image.Image):
                return self.encode_image(img)
            elif isinstance(img, bytes):
                return base64.b64encode(img).decode("utf-8")
        raise ValueError(
            f"Could not find image for item {item.get('ID', item.get('problem_id', 'unknown'))}"
        )

    def build_messages(self, item: dict) -> List[dict]:
        """Build prompt with question, options, and optional hint (MCQ format)."""
        image_base64 = self.get_image_base64(item)
        question = item.get("question", "")

        # Build options from A-H if present
        options = {}
        for letter in string.ascii_uppercase[:8]:  # A-H
            if (
                letter in item
                and item[letter] is not None
                and not pd.isna(item.get(letter, float("nan")))
            ):
                options[letter] = item[letter]

        # Build prompt
        prompt_parts = []

        # Add hint if present
        hint = item.get("hint", "")
        if hint and not pd.isna(hint):
            prompt_parts.append(f"Hint: {hint}")

        prompt_parts.append(f"Question: {question}")

        # Add options if present
        if options:
            options_text = "Options:\n"
            for letter, value in options.items():
                options_text += f"{letter}. {value}\n"
            prompt_parts.append(options_text)

        # Add COT requirement if dataset is WeMath_COT
        use_cot = getattr(self, "use_cot", False)
        requirement = item.get("requirement", "")
        if use_cot and requirement and not pd.isna(requirement):
            prompt_parts.append(requirement)
        else:
            prompt_parts.append(
                "Answer with the option's letter from the given choices directly."
            )

        prompt = "\n".join(prompt_parts)
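
        # A rendered prompt for a hypothetical item looks like:
        #   Hint: ...
        #   Question: What is the perimeter of the square?
        #   Options:
        #   A. 12
        #   B. 16
        #   Answer with the option's letter from the given choices directly.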

        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def extract_answer(self, response: str) -> str:
        """
        Extract MCQ answer letter from response.

        Following VLMEvalKit logic: look for letter after "Answer" keyword,
        or extract first valid letter (A-H).
        """
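        # Examples: "Answer: C" -> "C" (keyword match); "B. 42" -> "B"
        # (first character after stripping separators); "** D **" -> "D"
        # (fallback scan for any A-H letter).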
        response = str(response).strip()

        # Try to find answer after "Answer" keyword
        answer_match = re.search(r"Answer[:\s]*([A-Ha-h])", response, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).upper()

        # Clean response and look for first valid letter
        cleaned = re.sub(r"[>><<:.\s]", "", response).strip()
        if cleaned and cleaned[0].upper() in "ABCDEFGH":
            return cleaned[0].upper()

        # Fallback: find any letter A-H in the response
        for char in response.upper():
            if char in "ABCDEFGH":
                return char

        return ""

    def score(self, prediction: str, answer: str) -> bool:
        """Check if prediction matches answer (case-insensitive)."""
        if not prediction:
            return False
        return prediction.upper() == answer.upper()

    async def run_item(self, server: ServerManager, data_item: dict) -> Tuple[dict, dict]:
        try:
            messages = self.build_messages(data_item)

            completion = await self.chat_completion(server, messages)

            if not completion.choices:
                return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}
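
            # Some OpenAI-compatible backends return reasoning-model output in
            # a non-standard `reasoning` field (as an attribute or via
            # model_extra) rather than `content`; fall back to it when content
            # is empty.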
            message = completion.choices[0].message
            response = message.content or ""
            if hasattr(message, "reasoning") and message.reasoning and not response:
                response = message.reasoning
            if not response and hasattr(message, "model_extra"):
                reasoning = message.model_extra.get("reasoning", "")
                if reasoning:
                    response = reasoning

            if not response:
                return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"}

            extracted = self.extract_answer(response)
            answer = data_item.get("answer", "")
            correct = self.score(extracted, answer)

            # Get problem metadata for 4-dimensional analysis
            problem_id = data_item.get("ID", data_item.get("problem_id", ""))
            key = data_item.get("key", "")  # e.g., "2steps_1", "2steps_multi", etc.

            sample = {
                "ID": problem_id,
                "key": key,
                "question": data_item.get("question", ""),
                "answer": answer,
                "prediction": extracted,
                "raw_response": response[:500],  # Truncate for logging
                "hit": 1 if correct else 0,
                "joker": correct,  # VLMEvalKit naming convention
            }

            return {
                "accuracy": 1.0 if correct else 0.0,
                "hit": 1 if correct else 0,
            }, sample

        except Exception as e:
            return {"accuracy": 0.0, "hit": 0}, {"error": str(e)}


def compute_4d_metrics(samples: List[dict]) -> Dict:
    """
    Compute We-Math 4-dimensional metrics from evaluation samples.

    This implements the evaluation logic from VLMEvalKit's wemath.py.

    Returns metrics for:
    - IK (Insufficient Knowledge): Steps wrong AND multi wrong
    - IG (Inadequate Generalization): Steps right BUT multi wrong
    - CM (Complete Mastery): Steps right AND multi right
    - RM (Rote Memorization): Steps wrong BUT multi right
    """
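    # Data layout: each We-Math problem contributes one row per decomposed
    # sub-step (key "2steps_1", "2steps_2", ...) plus one row for the
    # composite question (key "2steps_multi"), all sharing an ID; the helpers
    # below merge these wide so per-problem step/multi outcomes line up.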
    # Convert samples to DataFrame
    df = pd.DataFrame(samples)

    if "key" not in df.columns or df["key"].isna().all():
        # Dataset doesn't have step structure, return basic accuracy
        return {
            "overall_accuracy": df["hit"].mean() if "hit" in df.columns else 0.0,
            "note": "Dataset lacks step structure for 4D metrics",
        }

    # Separate by step type
    data_2steps = df[df["key"].str.contains("2steps", na=False)]
    data_3steps = df[df["key"].str.contains("3steps", na=False)]

    metrics = {}

    # Process 2-step problems
    if len(data_2steps) > 0:
        merged_2steps = _process_steps_data(data_2steps, 2)
        if merged_2steps is not None:
            metrics["2step"] = _calculate_step_metrics(merged_2steps, 2)

    # Process 3-step problems
    if len(data_3steps) > 0:
        merged_3steps = _process_steps_data(data_3steps, 3)
        if merged_3steps is not None:
            metrics["3step"] = _calculate_step_metrics(merged_3steps, 3)

    # Compute overall 4D metrics
    if "2step" in metrics or "3step" in metrics:
        total_counts = _compute_total_counts(metrics)
        total_count = 525  # Standard We-Math total

        # Compute rates and final scores
        final_metrics = _compute_final_scores(total_counts, total_count)
        metrics["overall"] = final_metrics

    # Basic accuracy
    metrics["step_accuracy"] = df["hit"].mean() if "hit" in df.columns else 0.0

    return metrics


def _process_steps_data(df: pd.DataFrame, steps: int) -> Optional[pd.DataFrame]:
    """Process step data and merge by problem ID; returns None if incomplete."""
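    # Columns get per-step suffixes ("joker_1", "joker_2", ...) and "_multi"
    # for the composite row, then everything is merged on the shared problem
    # ID so one output row carries every outcome for a problem.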
    try:
        steps_data = {}
        for i in range(1, steps + 1):
            key = f"{steps}steps_{i}"
            step_df = df[df["key"] == key].copy()
            if len(step_df) == 0:
                return None
            step_df.columns = [f"{col}_{i}" for col in step_df.columns]
            steps_data[i] = step_df

        # Get multi-step data
        multi_key = f"{steps}steps_multi"
        multi_df = df[df["key"] == multi_key].copy()
        if len(multi_df) == 0:
            return None
        multi_df.columns = [f"{col}_multi" for col in multi_df.columns]

        # Merge all steps
        merged = steps_data[1]
        for i in range(2, steps + 1):
            merged = pd.merge(
                merged,
                steps_data[i],
                left_on="ID_1",
                right_on=f"ID_{i}",
                how="left",
            )
        merged = pd.merge(
            merged, multi_df, left_on="ID_1", right_on="ID_multi", how="left"
        )

        return merged
    except Exception:
        return None


def _calculate_step_metrics(merged: pd.DataFrame, steps: int) -> Dict:
    """Calculate metrics for a step type (2-step or 3-step)."""
    try:
        # Get joker columns
        joker_cols = [f"joker_{i}" for i in range(1, steps + 1)]
        joker_multi = "joker_multi"

        # Check if columns exist
        for col in joker_cols + [joker_multi]:
            if col not in merged.columns:
                return {}

        # Calculate conditions
        all_steps_correct = merged[joker_cols].all(axis=1)
        any_step_correct = merged[joker_cols].any(axis=1)
        all_steps_wrong = ~merged[joker_cols].any(axis=1)
        any_step_wrong = ~merged[joker_cols].all(axis=1)
        multi_correct = merged[joker_multi] == True  # noqa: E712

        return {
            # Strict: ALL steps must be correct
            "CM_strict": int((all_steps_correct & multi_correct).sum()),
            "IG": int((all_steps_correct & ~multi_correct).sum()),
            "RM_strict": int((any_step_wrong & multi_correct).sum()),
            "IK": int((any_step_wrong & ~multi_correct).sum()),
            # Loose: ANY step correct
            "CM_loose": int((any_step_correct & multi_correct).sum()),
            "RM_loose": int((all_steps_wrong & multi_correct).sum()),
            "total": len(merged),
        }
    except Exception:
        return {}


def _compute_total_counts(metrics: Dict) -> Dict:
    """Aggregate counts across step types."""
    totals = defaultdict(int)

    for step_type in ["2step", "3step"]:
        if step_type in metrics:
            for key in ["CM_strict", "CM_loose", "IG", "RM_strict", "RM_loose", "IK"]:
                if key in metrics[step_type]:
                    totals[key] += metrics[step_type][key]

    return dict(totals)


def _compute_final_scores(total_counts: Dict, total_count: int = 525) -> Dict:
    """Compute final 4D scores and rates."""
    results = {}

    # Calculate rates
    for key in ["IK", "IG", "CM_strict", "CM_loose", "RM_strict", "RM_loose"]:
        count = total_counts.get(key, 0)
        results[f"{key}_count"] = count
        results[f"{key}_rate"] = count / total_count if total_count > 0 else 0.0

    # Calculate RM rates (relative to CM + RM)
    cm_rm_strict = total_counts.get("CM_strict", 0) + total_counts.get("RM_strict", 0)
    cm_rm_loose = total_counts.get("CM_loose", 0) + total_counts.get("RM_loose", 0)

    results["RM_strict_relative"] = (
        total_counts.get("RM_strict", 0) / cm_rm_strict if cm_rm_strict > 0 else 0.0
    )
    results["RM_loose_relative"] = (
        total_counts.get("RM_loose", 0) / cm_rm_loose if cm_rm_loose > 0 else 0.0
    )

    # Final scores (VLMEvalKit formula)
    results["score_strict"] = (
        total_count
        - 0.5 * total_counts.get("IG", 0)
        - total_counts.get("RM_strict", 0)
        - total_counts.get("IK", 0)
    ) / total_count

    results["score_loose"] = (
        total_count
        - 0.5 * total_counts.get("IG", 0)
        - total_counts.get("RM_loose", 0)
        - total_counts.get("IK", 0)
    ) / total_count
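
    # Worked example with hypothetical counts (IG=40, RM_strict=30, IK=100):
    # score_strict = (525 - 0.5*40 - 30 - 100) / 525 = 375 / 525 ≈ 0.714.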

    return results
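

# Entry point: runs the testmini split with greedy decoding (temperature=0.0).
# Setting use_cot=True makes build_messages append each item's "requirement"
# (a chain-of-thought instruction), where present, instead of the
# direct-answer instruction.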
if __name__ == "__main__":
    asyncio.run(
        eval_runner(WeMath(split="testmini", use_cot=False, temperature=0.0, max_tokens=512))
    )