diff --git a/eval/visualize_results.py b/eval/visualize_results.py index 64b2c3d7..1c5514f1 100644 --- a/eval/visualize_results.py +++ b/eval/visualize_results.py @@ -366,85 +366,80 @@ def create_performance_distribution_violin(summaries: Dict[str, Dict[str, Any]]) return fig -def create_performance_heatmap(summaries: Dict[str, Dict[str, Any]], categories: Dict[str, List[str]]) -> Figure: - """Create a heatmap of model performance across datasets. +def create_performance_heatmap( + summaries: Dict[str, Dict[str, Any]], + categories: Dict[str, List[str]], +) -> Figure: + """ + Heat-map of model performance (0–100 %) across individual datasets. - Args: - summaries: Dictionary of model summaries - categories: Dictionary mapping categories to dataset lists - - Returns: - Matplotlib figure + Rows : models (sorted by overall mean score, high→low) + Cols : datasets grouped by `categories` + Cell : 100 × raw score """ if not summaries: logger.error("No summaries provided") return plt.figure() - # Get all dataset names - all_datasets = [] - for category, datasets in sorted(categories.items()): - all_datasets.extend(sorted(datasets)) + # ---- gather dataset names in category order + all_datasets: List[str] = [] + for cat, ds in sorted(categories.items()): + all_datasets.extend(sorted(ds)) - # Sort models by overall performance - overall_scores = {} - for model_name, summary in summaries.items(): - scores = list(summary["dataset_best_scores"].values()) - overall_scores[model_name] = np.mean(scores) - models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)] + # ---- sort models by overall performance + overall = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()} + models = [m for m, _ in sorted(overall.items(), key=lambda x: x[1], reverse=True)] - # Create score matrix + # ---- build score matrix (0–100) score_matrix = np.zeros((len(models), len(all_datasets))) - for i, model in enumerate(models): - for j, dataset in enumerate(all_datasets): - score_matrix[i, j] = summaries[model]["dataset_best_scores"].get(dataset, 0) + for j, ds in enumerate(all_datasets): + score_matrix[i, j] = 100 * summaries[model]["dataset_best_scores"].get(ds, 0.0) - # Create heatmap + # ---- plot fig, ax = plt.subplots(figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5))) - im = ax.imshow(score_matrix, cmap="viridis", aspect="auto", vmin=0, vmax=1) + im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100) - # Add colorbar - cbar = ax.figure.colorbar(im, ax=ax) - cbar.ax.set_ylabel("Score", rotation=-90, va="bottom") + # colour-bar + cbar = fig.colorbar(im, ax=ax) + cbar.ax.set_ylabel("Score (%)", rotation=-90, va="bottom") - # Set ticks and labels + # ticks & labels ax.set_xticks(np.arange(len(all_datasets))) + ax.set_xticklabels(all_datasets, rotation=270, fontsize=8) ax.set_yticks(np.arange(len(models))) - ax.set_xticklabels(all_datasets, rotation=90, fontsize=8) ax.set_yticklabels(models) - # Add category separators and labels - current_idx = 0 - for category, datasets in sorted(categories.items()): - if datasets: - # Add vertical line after each category - next_idx = current_idx + len(datasets) - if next_idx < len(all_datasets): - ax.axvline(x=next_idx - 0.5, color="white", linestyle="-", linewidth=2) + # category separators & titles + current = 0 + label_offset = -0.25 # ↓ push labels down (was around −0.7) + for cat, ds in sorted(categories.items()): + if not ds: + continue + nxt = current + len(ds) + if nxt < len(all_datasets): + ax.axvline(nxt - 0.5, color="white", linewidth=2) - # Add category label - middle_idx = current_idx + len(datasets) / 2 - 0.5 - ax.text( - middle_idx, - -0.5, - category, - ha="center", - va="top", - fontsize=10, - bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"), - ) + mid = current + len(ds) / 2 - 0.5 + ax.text( + mid, + label_offset, # <-- use offset + cat, + ha="center", + va="top", + fontsize=10, + bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"), + ) + current = nxt - current_idx = next_idx - - # Add grid lines + # grid (mirrors comparison-plot style) ax.set_xticks(np.arange(-0.5, len(all_datasets), 1), minor=True) ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True) ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5) - plt.title("Model Performance Heatmap", size=15) + # ax.set_title("Model Performance by Dataset", fontsize=15) plt.tight_layout() - return fig