This commit is contained in:
Zafir Stojanovski 2025-05-19 10:07:45 +02:00 committed by GitHub
parent add527ada1
commit 93e731c29c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -366,85 +366,80 @@ def create_performance_distribution_violin(summaries: Dict[str, Dict[str, Any]])
return fig
def create_performance_heatmap(
    summaries: Dict[str, Dict[str, Any]],
    categories: Dict[str, List[str]],
) -> Figure:
    """Create a heat-map of model performance (0–100 %) across datasets.

    Layout of the resulting image:

    * Rows: models, sorted by overall mean score (high → low).
    * Cols: datasets, grouped by ``categories`` (categories in sorted key
      order, datasets sorted by name within each category).
    * Cell: ``100 × raw score``; a dataset missing from a model's scores
      is shown as 0.

    Args:
        summaries: Mapping of model name to its summary dict. Each summary
            must contain a ``"dataset_best_scores"`` mapping of dataset
            name to raw score.
        categories: Mapping of category name to the list of dataset names
            belonging to that category.

    Returns:
        The Matplotlib figure holding the heat-map. When ``summaries`` is
        empty an error is logged and a blank figure is returned.

    Raises:
        KeyError: If a summary has no ``"dataset_best_scores"`` entry.
    """
    if not summaries:
        logger.error("No summaries provided")
        return plt.figure()

    # ---- gather dataset names in category order
    ordered_categories = sorted(categories.items())
    all_datasets: List[str] = [
        name for _, names in ordered_categories for name in sorted(names)
    ]

    # ---- sort models by overall performance
    overall: Dict[str, float] = {}
    for model, summary in summaries.items():
        scores = list(summary["dataset_best_scores"].values())
        # np.mean([]) is NaN (plus a RuntimeWarning) and NaN makes the sort
        # order unreliable, so a model without scores ranks as 0.0.
        overall[model] = float(np.mean(scores)) if scores else 0.0
    models = sorted(overall, key=overall.get, reverse=True)

    # ---- build score matrix (0–100)
    score_matrix = np.zeros((len(models), len(all_datasets)))
    for i, model in enumerate(models):
        best_scores = summaries[model]["dataset_best_scores"]
        for j, dataset in enumerate(all_datasets):
            score_matrix[i, j] = 100 * best_scores.get(dataset, 0.0)

    # ---- plot (figure grows with the number of datasets / models)
    fig, ax = plt.subplots(
        figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5))
    )
    im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100)

    # colour-bar
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Score (%)", rotation=-90, va="bottom")

    # ticks & labels
    ax.set_xticks(np.arange(len(all_datasets)))
    ax.set_yticks(np.arange(len(models)))
    ax.set_xticklabels(all_datasets, rotation=90, fontsize=8)
    ax.set_yticklabels(models)

    # category separators & titles
    # Labels are anchored by their top edge at this y data coordinate.
    label_offset = -0.25
    current = 0
    for category, names in ordered_categories:
        if not names:
            continue
        nxt = current + len(names)
        # Separator after every category except the last one.
        if nxt < len(all_datasets):
            ax.axvline(nxt - 0.5, color="white", linewidth=2)

        # Centre the label horizontally over the category's columns.
        mid = current + len(names) / 2 - 0.5
        ax.text(
            mid,
            label_offset,
            category,
            ha="center",
            va="top",
            fontsize=10,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
        current = nxt

    # grid: minor ticks sit on the cell borders so the grid outlines cells
    ax.set_xticks(np.arange(-0.5, len(all_datasets), 1), minor=True)
    ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)

    ax.set_title("Model Performance Heatmap", size=15)
    fig.tight_layout()
    return fig