mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
heatmap (#438)
This commit is contained in:
parent
add527ada1
commit
93e731c29c
1 changed files with 48 additions and 53 deletions
|
|
@ -366,85 +366,80 @@ def create_performance_distribution_violin(summaries: Dict[str, Dict[str, Any]])
|
|||
return fig
|
||||
|
||||
|
||||
def create_performance_heatmap(summaries: Dict[str, Dict[str, Any]], categories: Dict[str, List[str]]) -> Figure:
|
||||
"""Create a heatmap of model performance across datasets.
|
||||
def create_performance_heatmap(
|
||||
summaries: Dict[str, Dict[str, Any]],
|
||||
categories: Dict[str, List[str]],
|
||||
) -> Figure:
|
||||
"""
|
||||
Heat-map of model performance (0–100 %) across individual datasets.
|
||||
|
||||
Args:
|
||||
summaries: Dictionary of model summaries
|
||||
categories: Dictionary mapping categories to dataset lists
|
||||
|
||||
Returns:
|
||||
Matplotlib figure
|
||||
Rows : models (sorted by overall mean score, high→low)
|
||||
Cols : datasets grouped by `categories`
|
||||
Cell : 100 × raw score
|
||||
"""
|
||||
if not summaries:
|
||||
logger.error("No summaries provided")
|
||||
return plt.figure()
|
||||
|
||||
# Get all dataset names
|
||||
all_datasets = []
|
||||
for category, datasets in sorted(categories.items()):
|
||||
all_datasets.extend(sorted(datasets))
|
||||
# ---- gather dataset names in category order
|
||||
all_datasets: List[str] = []
|
||||
for cat, ds in sorted(categories.items()):
|
||||
all_datasets.extend(sorted(ds))
|
||||
|
||||
# Sort models by overall performance
|
||||
overall_scores = {}
|
||||
for model_name, summary in summaries.items():
|
||||
scores = list(summary["dataset_best_scores"].values())
|
||||
overall_scores[model_name] = np.mean(scores)
|
||||
models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
|
||||
# ---- sort models by overall performance
|
||||
overall = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
|
||||
models = [m for m, _ in sorted(overall.items(), key=lambda x: x[1], reverse=True)]
|
||||
|
||||
# Create score matrix
|
||||
# ---- build score matrix (0–100)
|
||||
score_matrix = np.zeros((len(models), len(all_datasets)))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
for j, dataset in enumerate(all_datasets):
|
||||
score_matrix[i, j] = summaries[model]["dataset_best_scores"].get(dataset, 0)
|
||||
for j, ds in enumerate(all_datasets):
|
||||
score_matrix[i, j] = 100 * summaries[model]["dataset_best_scores"].get(ds, 0.0)
|
||||
|
||||
# Create heatmap
|
||||
# ---- plot
|
||||
fig, ax = plt.subplots(figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5)))
|
||||
|
||||
im = ax.imshow(score_matrix, cmap="viridis", aspect="auto", vmin=0, vmax=1)
|
||||
im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100)
|
||||
|
||||
# Add colorbar
|
||||
cbar = ax.figure.colorbar(im, ax=ax)
|
||||
cbar.ax.set_ylabel("Score", rotation=-90, va="bottom")
|
||||
# colour-bar
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.ax.set_ylabel("Score (%)", rotation=-90, va="bottom")
|
||||
|
||||
# Set ticks and labels
|
||||
# ticks & labels
|
||||
ax.set_xticks(np.arange(len(all_datasets)))
|
||||
ax.set_xticklabels(all_datasets, rotation=270, fontsize=8)
|
||||
ax.set_yticks(np.arange(len(models)))
|
||||
ax.set_xticklabels(all_datasets, rotation=90, fontsize=8)
|
||||
ax.set_yticklabels(models)
|
||||
|
||||
# Add category separators and labels
|
||||
current_idx = 0
|
||||
for category, datasets in sorted(categories.items()):
|
||||
if datasets:
|
||||
# Add vertical line after each category
|
||||
next_idx = current_idx + len(datasets)
|
||||
if next_idx < len(all_datasets):
|
||||
ax.axvline(x=next_idx - 0.5, color="white", linestyle="-", linewidth=2)
|
||||
# category separators & titles
|
||||
current = 0
|
||||
label_offset = -0.25 # ↓ push labels down (was around −0.7)
|
||||
for cat, ds in sorted(categories.items()):
|
||||
if not ds:
|
||||
continue
|
||||
nxt = current + len(ds)
|
||||
if nxt < len(all_datasets):
|
||||
ax.axvline(nxt - 0.5, color="white", linewidth=2)
|
||||
|
||||
# Add category label
|
||||
middle_idx = current_idx + len(datasets) / 2 - 0.5
|
||||
ax.text(
|
||||
middle_idx,
|
||||
-0.5,
|
||||
category,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=10,
|
||||
bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
|
||||
)
|
||||
mid = current + len(ds) / 2 - 0.5
|
||||
ax.text(
|
||||
mid,
|
||||
label_offset, # <-- use offset
|
||||
cat,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=10,
|
||||
bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
|
||||
)
|
||||
current = nxt
|
||||
|
||||
current_idx = next_idx
|
||||
|
||||
# Add grid lines
|
||||
# grid (mirrors comparison-plot style)
|
||||
ax.set_xticks(np.arange(-0.5, len(all_datasets), 1), minor=True)
|
||||
ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True)
|
||||
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
|
||||
|
||||
plt.title("Model Performance Heatmap", size=15)
|
||||
# ax.set_title("Model Performance by Dataset", fontsize=15)
|
||||
plt.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue