diff --git a/eval/visualize_results.py b/eval/visualize_results.py
index 9d6966d5..64b2c3d7 100644
--- a/eval/visualize_results.py
+++ b/eval/visualize_results.py
@@ -587,76 +587,86 @@ def create_dashboard(summaries: Dict[str, Dict[str, Any]], categories: Dict[str,
 def create_comparison_plot(
     summaries: Dict[str, Dict[str, Any]],
     other_summaries: Dict[str, Dict[str, Any]],
-    model_id: str,
     categories: Optional[Dict[str, List[str]]] = None,
 ) -> Figure:
-    """Create a comparison plot between two models.
+    """
+    Build a heat-map of per-category score differences (scaled to -100 … 100).
 
-    Args:
-        summaries: Dictionary of model summaries
-        other_summaries: Dictionary of other model summaries for comparison
-        model_id: Model ID to compare with
-        categories: Dictionary mapping categories to dataset lists
+    Rows  : model IDs present in both `summaries` and `other_summaries`
+    Cols  : category names (`categories`)
+    Value : 100 * (mean(score in summaries) - mean(score in other_summaries))
 
-    Returns:
-        Matplotlib figure
+    A numeric annotation (rounded to 2 dp) is rendered in every cell.
     """
     if not summaries or not other_summaries:
         logger.error("No summaries provided for comparison")
         return plt.figure()
 
-    current_scores, baseline_scores = {}, {}
+    if categories is None:
+        all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
+        categories = {"all": list(all_ds)}
+    # models appearing in both result sets
+    common_models = [m for m in summaries if m in other_summaries]
+    if not common_models:
+        logger.error("No overlapping model IDs between the two result sets.")
+        return plt.figure()
+
+    # sort models by overall performance
+    overall_scores = {}
 
     for model_name, summary in summaries.items():
-        if model_name == model_id:
-            for category, datasets in categories.items():
-                datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
-                if datasets_scores:  # Avoid division by zero
-                    current_scores[category] = np.mean(datasets_scores)
-                else:
-                    current_scores[category] = 0
+        scores = list(summary["dataset_best_scores"].values())
+        overall_scores[model_name] = np.mean(scores)
+    models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
+    common_models = [m for m in models if m in common_models]
 
-    for model_name, summary in other_summaries.items():
-        if model_name == model_id:
-            for category, datasets in categories.items():
-                datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
-                if datasets_scores:
-                    baseline_scores[category] = np.mean(datasets_scores)
-                else:
-                    baseline_scores[category] = 0
+    category_list = sorted(categories.keys())
+    diff_matrix = np.zeros((len(common_models), len(category_list)))
 
-    logger.debug(f"Current scores: {current_scores}")
-    logger.debug(f"Baseline scores: {baseline_scores}")
+    # compute 100 × Δ
+    for i, model in enumerate(common_models):
+        cur_scores = summaries[model]["dataset_best_scores"]
+        base_scores = other_summaries[model]["dataset_best_scores"]
 
-    # Create a bar chart for comparison
-    fig, ax = plt.subplots(figsize=(20, 10))
-    categories_list = sorted(current_scores.keys())
-    current_values = [round(current_scores[cat] * 100, 2) for cat in categories_list]
-    baseline_values = [round(baseline_scores[cat] * 100, 2) for cat in categories_list]
+        for j, cat in enumerate(category_list):
+            ds = categories[cat]
+            cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
+            base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
+            diff_matrix[i, j] = 100 * (cur_mean - base_mean)  # scale to -100 … 100
 
-    x = np.arange(len(categories_list))
-    width = 0.35
-    colors = plt.cm.tab10.colors
-    bars1 = ax.bar(x - width / 2, baseline_values, width, label="Baseline", color=colors[1])
-    bars2 = ax.bar(x + width / 2, current_values, width, label="Difficult", color=colors[0])
-    ax.set_ylabel("Average Score")
-    ax.set_title(f"Performance of {model_id} when increasing difficulty", size=15)
-    ax.set_xticks(x)
-    ax.set_xticklabels(categories_list, rotation=45, ha="right")
-    ax.legend()
-    ax.set_ylim(0, max(max(current_values), max(baseline_values)) * 1.1)
-    plt.tight_layout()
-    plt.grid(axis="y")
+    # ---------------------------------------------------------------- Plot
+    fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
 
-    # Add value labels on top of bars
-    for bar in bars1:
-        height = bar.get_height()
-        ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
-    for bar in bars2:
-        height = bar.get_height()
-        ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
+    im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
 
-    plt.title(f"Performance of {model_id.split('/')[-1]} on Increasing Difficulty", size=15)
+    # colour-bar
+    cbar = fig.colorbar(im, ax=ax)
+    cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
+
+    # ticks / labels
+    ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
+    ax.set_yticks(np.arange(len(common_models)), labels=common_models)
+
+    # grid for readability
+    ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
+    ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
+
+    # annotate each cell
+    for i in range(len(common_models)):
+        for j in range(len(category_list)):
+            value = diff_matrix[i, j]
+            ax.text(
+                j,
+                i,
+                f"{value:.2f}",
+                ha="center",
+                va="center",
+                color="black" if abs(value) < 50 else "white",
+                fontsize=8,
            )
+
+    ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
     plt.tight_layout()
 
     return fig
@@ -768,18 +778,8 @@ def main():
             if not other_summaries:
                 logger.error("No valid summaries found in comparison directory. Exiting.")
                 return 1
-
-            comparison_output_dir = args.output_dir / "comparison"
-            if not comparison_output_dir.exists():
-                logger.info(f"Creating comparison output directory {comparison_output_dir}")
-                comparison_output_dir.mkdir(parents=True, exist_ok=True)
-
-            for model_name in summaries.keys():
-                if model_name not in other_summaries:
-                    logger.warning(f"Model {model_name} not found in comparison directory. Skippping...")
-                    continue
-                fig = create_comparison_plot(summaries, other_summaries, model_name, categories)
-                save_figure(fig, comparison_output_dir, model_name, args.format, args.dpi)
+            fig = create_comparison_plot(summaries, other_summaries, categories)
+            save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
         else:
            logger.warning(f"Unknown plot type: {plot_type}")
 
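For reviewers, a minimal sketch of how the reworked `create_comparison_plot` is expected to be driven. The model names, dataset names, and scores below are invented placeholders that only mimic the `{"dataset_best_scores": {...}}` shape this script already produces; they are not real evaluation output, and the import path simply assumes the module location shown in this diff.

```python
# Hypothetical smoke test for the new heat-map comparison plot.
# All names and numbers are made up for illustration only.
from eval.visualize_results import create_comparison_plot

hard_results = {
    "model-a": {"dataset_best_scores": {"gsm8k_hard": 0.61, "math_hard": 0.38}},
    "model-b": {"dataset_best_scores": {"gsm8k_hard": 0.55, "math_hard": 0.30}},
}
easy_results = {
    "model-a": {"dataset_best_scores": {"gsm8k_hard": 0.78, "math_hard": 0.52}},
    "model-b": {"dataset_best_scores": {"gsm8k_hard": 0.70, "math_hard": 0.41}},
}
categories = {"math": ["gsm8k_hard", "math_hard"]}

# Each heat-map cell is 100 * (hard mean - easy mean) for that model/category pair.
fig = create_comparison_plot(hard_results, easy_results, categories)
fig.savefig("model_category_delta_heatmap.png", dpi=150)
```

With these toy numbers every cell comes out negative (roughly -15.5 for model-a and -13.0 for model-b), which is the expected direction when the first result set is the harder one.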