fix(eval): comparison plot (#441)

* heatmap

* filter comparison plots

* latex style

* curriculum heatmap

* pre-commit

* update figsize

* large y-ticks

* larger font

* thinner

* include 50
This commit is contained in:
Zafir Stojanovski 2025-05-29 12:31:07 +02:00 committed by GitHub
parent f51769927e
commit b843f33b1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 198 additions and 30 deletions

View file

@ -27,6 +27,7 @@ import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import matplotlib
import matplotlib.colors as mcolors import matplotlib.colors as mcolors
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -42,6 +43,22 @@ logging.basicConfig(
logger = logging.getLogger("visualize_results") logger = logging.getLogger("visualize_results")
plt.rcParams.update(
{
"text.usetex": True,
"font.family": "serif",
"font.serif": ["Computer Modern Roman"],
"text.latex.preamble": r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
"axes.labelsize": 20,
"font.size": 20,
"legend.fontsize": 14,
"xtick.labelsize": 14,
"ytick.labelsize": 14,
"axes.titlesize": 22,
}
)
def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]: def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
"""Load all summary.json files from subdirectories. """Load all summary.json files from subdirectories.
@ -583,13 +600,14 @@ def create_comparison_plot(
summaries: Dict[str, Dict[str, Any]], summaries: Dict[str, Dict[str, Any]],
other_summaries: Dict[str, Dict[str, Any]], other_summaries: Dict[str, Dict[str, Any]],
categories: Optional[Dict[str, List[str]]] = None, categories: Optional[Dict[str, List[str]]] = None,
compare_model_ids: Optional[List[str]] = None,
) -> Figure: ) -> Figure:
""" """
Build a heat-map of per-category score differences (scaled to −100 … 100). Build a heat-map of per-category score differences (scaled to −100 … 100).
Rows : model IDs present in both `summaries` and `other_summaries` Rows : category names (`categories`)
Cols : category names (`categories`) Cols : model IDs present in both `summaries` and `other_summaries`
Value : 100 * (mean(score in summaries) − mean(score in other_summaries)) Value : 100 × (mean(score in summaries) − mean(score in other_summaries))
A numeric annotation (rounded to 2 dp) is rendered in every cell. A numeric annotation (rounded to 2 dp) is rendered in every cell.
""" """
@ -601,55 +619,53 @@ def create_comparison_plot(
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys() all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
categories = {"all": list(all_ds)} categories = {"all": list(all_ds)}
# models appearing in both result sets # models present in both result sets
common_models = [m for m in summaries if m in other_summaries] common_models = [m for m in summaries if m in other_summaries]
if not common_models: if not common_models:
logger.error("No overlapping model IDs between the two result sets.") logger.error("No overlapping model IDs between the two result sets.")
return plt.figure() return plt.figure()
# sort models by overall performance # sort models by overall performance
overall_scores = {} overall_scores = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
for model_name, summary in summaries.items(): models = [m for m, _ in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True) if m in common_models]
scores = list(summary["dataset_best_scores"].values()) if compare_model_ids:
overall_scores[model_name] = np.mean(scores) models = [m for m in models if m in compare_model_ids]
models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
common_models = [m for m in models if m in common_models]
category_list = sorted(categories.keys()) category_list = sorted(categories.keys())
diff_matrix = np.zeros((len(common_models), len(category_list))) # ---------- note the transposed shape (categories × models)
diff_matrix = np.zeros((len(category_list), len(models)))
# compute 100 × Δ # compute 100 × Δ
for i, model in enumerate(common_models): for i, cat in enumerate(category_list):
ds = categories[cat]
for j, model in enumerate(models):
cur_scores = summaries[model]["dataset_best_scores"] cur_scores = summaries[model]["dataset_best_scores"]
base_scores = other_summaries[model]["dataset_best_scores"] base_scores = other_summaries[model]["dataset_best_scores"]
for j, cat in enumerate(category_list):
ds = categories[cat]
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0 cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0 base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
diff_matrix[i, j] = 100 * (cur_mean - base_mean) # scale to -100 … 100 diff_matrix[i, j] = 100 * (cur_mean - base_mean)
# ---------------------------------------------------------------- Plot # ---------------------------------------------------------------- plot
fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5))) fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100) im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
# colour-bar # colour-bar
cbar = fig.colorbar(im, ax=ax) cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom") cbar.ax.set_ylabel("$\Delta$ score (\%)", rotation=-90, va="bottom", fontweight="bold")
# ticks / labels # ticks / labels
ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right") ax.set_xticks(np.arange(len(models)), labels=models, rotation=45, ha="right")
ax.set_yticks(np.arange(len(common_models)), labels=common_models) ax.set_yticks(np.arange(len(category_list)), labels=category_list)
# grid for readability # grid for readability
ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True) ax.set_xticks(np.arange(-0.5, len(models), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True) ax.set_yticks(np.arange(-0.5, len(category_list), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5) ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# annotate each cell # annotate each cell
for i in range(len(common_models)): for i in range(len(category_list)):
for j in range(len(category_list)): for j in range(len(models)):
value = diff_matrix[i, j] value = diff_matrix[i, j]
ax.text( ax.text(
j, j,
@ -658,10 +674,10 @@ def create_comparison_plot(
ha="center", ha="center",
va="center", va="center",
color="black" if abs(value) < 50 else "white", color="black" if abs(value) < 50 else "white",
fontsize=8, fontsize=12,
) )
ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14) # ax.set_title("Per-Category Performance $\Delta$ (hard − easy)", fontweight="bold")
plt.tight_layout() plt.tight_layout()
return fig return fig
@ -702,6 +718,7 @@ def main():
"--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot" "--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
) )
parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None) parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
parser.add_argument("--compare-model-ids", help="Comma-separated list of model IDs to compare", default=None)
parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots") parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
parser.add_argument("--dpi", type=int, default=300, help="DPI for output images") parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them") parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
@ -773,7 +790,8 @@ def main():
if not other_summaries: if not other_summaries:
logger.error("No valid summaries found in comparison directory. Exiting.") logger.error("No valid summaries found in comparison directory. Exiting.")
return 1 return 1
fig = create_comparison_plot(summaries, other_summaries, categories) compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
fig = create_comparison_plot(summaries, other_summaries, categories, compare_model_ids)
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi) save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
else: else:

File diff suppressed because one or more lines are too long