mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix(eval): comparison plot (#441)
* heatmap * filter comparison plots * latex style * curriculum heatmap * pre-commit * update figsize * large y-ticks * larger font * thinner * include 50
This commit is contained in:
parent
f51769927e
commit
b843f33b1d
2 changed files with 198 additions and 30 deletions
|
|
@ -27,6 +27,7 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
import matplotlib.colors as mcolors
|
import matplotlib.colors as mcolors
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -42,6 +43,22 @@ logging.basicConfig(
|
||||||
logger = logging.getLogger("visualize_results")
|
logger = logging.getLogger("visualize_results")
|
||||||
|
|
||||||
|
|
||||||
|
plt.rcParams.update(
|
||||||
|
{
|
||||||
|
"text.usetex": True,
|
||||||
|
"font.family": "serif",
|
||||||
|
"font.serif": ["Computer Modern Roman"],
|
||||||
|
"text.latex.preamble": r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
|
||||||
|
"axes.labelsize": 20,
|
||||||
|
"font.size": 20,
|
||||||
|
"legend.fontsize": 14,
|
||||||
|
"xtick.labelsize": 14,
|
||||||
|
"ytick.labelsize": 14,
|
||||||
|
"axes.titlesize": 22,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
|
def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||||
"""Load all summary.json files from subdirectories.
|
"""Load all summary.json files from subdirectories.
|
||||||
|
|
||||||
|
|
@ -583,13 +600,14 @@ def create_comparison_plot(
|
||||||
summaries: Dict[str, Dict[str, Any]],
|
summaries: Dict[str, Dict[str, Any]],
|
||||||
other_summaries: Dict[str, Dict[str, Any]],
|
other_summaries: Dict[str, Dict[str, Any]],
|
||||||
categories: Optional[Dict[str, List[str]]] = None,
|
categories: Optional[Dict[str, List[str]]] = None,
|
||||||
|
compare_model_ids: Optional[List[str]] = None,
|
||||||
) -> Figure:
|
) -> Figure:
|
||||||
"""
|
"""
|
||||||
Build a heat-map of per-category score differences (scaled to –100 … 100).
|
Build a heat-map of per-category score differences (scaled to –100 … 100).
|
||||||
|
|
||||||
Rows : model IDs present in both `summaries` and `other_summaries`
|
Rows : category names (`categories`)
|
||||||
Cols : category names (`categories`)
|
Cols : model IDs present in both `summaries` and `other_summaries`
|
||||||
Value : 100 * (mean(score in summaries) − mean(score in other_summaries))
|
Value : 100 × (mean(score in summaries) − mean(score in other_summaries))
|
||||||
|
|
||||||
A numeric annotation (rounded to 2 dp) is rendered in every cell.
|
A numeric annotation (rounded to 2 dp) is rendered in every cell.
|
||||||
"""
|
"""
|
||||||
|
|
@ -601,55 +619,53 @@ def create_comparison_plot(
|
||||||
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
|
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
|
||||||
categories = {"all": list(all_ds)}
|
categories = {"all": list(all_ds)}
|
||||||
|
|
||||||
# models appearing in both result sets
|
# models present in both result sets
|
||||||
common_models = [m for m in summaries if m in other_summaries]
|
common_models = [m for m in summaries if m in other_summaries]
|
||||||
if not common_models:
|
if not common_models:
|
||||||
logger.error("No overlapping model IDs between the two result sets.")
|
logger.error("No overlapping model IDs between the two result sets.")
|
||||||
return plt.figure()
|
return plt.figure()
|
||||||
|
|
||||||
# sort models by overall performance
|
# sort models by overall performance
|
||||||
overall_scores = {}
|
overall_scores = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
|
||||||
for model_name, summary in summaries.items():
|
models = [m for m, _ in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True) if m in common_models]
|
||||||
scores = list(summary["dataset_best_scores"].values())
|
if compare_model_ids:
|
||||||
overall_scores[model_name] = np.mean(scores)
|
models = [m for m in models if m in compare_model_ids]
|
||||||
models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
|
|
||||||
common_models = [m for m in models if m in common_models]
|
|
||||||
|
|
||||||
category_list = sorted(categories.keys())
|
category_list = sorted(categories.keys())
|
||||||
diff_matrix = np.zeros((len(common_models), len(category_list)))
|
# ---------- note the transposed shape (categories × models)
|
||||||
|
diff_matrix = np.zeros((len(category_list), len(models)))
|
||||||
|
|
||||||
# compute 100 × Δ
|
# compute 100 × Δ
|
||||||
for i, model in enumerate(common_models):
|
for i, cat in enumerate(category_list):
|
||||||
|
ds = categories[cat]
|
||||||
|
for j, model in enumerate(models):
|
||||||
cur_scores = summaries[model]["dataset_best_scores"]
|
cur_scores = summaries[model]["dataset_best_scores"]
|
||||||
base_scores = other_summaries[model]["dataset_best_scores"]
|
base_scores = other_summaries[model]["dataset_best_scores"]
|
||||||
|
|
||||||
for j, cat in enumerate(category_list):
|
|
||||||
ds = categories[cat]
|
|
||||||
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
|
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
|
||||||
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
|
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
|
||||||
diff_matrix[i, j] = 100 * (cur_mean - base_mean) # scale to -100 … 100
|
diff_matrix[i, j] = 100 * (cur_mean - base_mean)
|
||||||
|
|
||||||
# ---------------------------------------------------------------- Plot
|
# ---------------------------------------------------------------- plot
|
||||||
fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
|
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))
|
||||||
|
|
||||||
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
|
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
|
||||||
|
|
||||||
# colour-bar
|
# colour-bar
|
||||||
cbar = fig.colorbar(im, ax=ax)
|
cbar = fig.colorbar(im, ax=ax)
|
||||||
cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
|
cbar.ax.set_ylabel("$\Delta$ score (\%)", rotation=-90, va="bottom", fontweight="bold")
|
||||||
|
|
||||||
# ticks / labels
|
# ticks / labels
|
||||||
ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
|
ax.set_xticks(np.arange(len(models)), labels=models, rotation=45, ha="right")
|
||||||
ax.set_yticks(np.arange(len(common_models)), labels=common_models)
|
ax.set_yticks(np.arange(len(category_list)), labels=category_list)
|
||||||
|
|
||||||
# grid for readability
|
# grid for readability
|
||||||
ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
|
ax.set_xticks(np.arange(-0.5, len(models), 1), minor=True)
|
||||||
ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
|
ax.set_yticks(np.arange(-0.5, len(category_list), 1), minor=True)
|
||||||
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
|
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
|
||||||
|
|
||||||
# annotate each cell
|
# annotate each cell
|
||||||
for i in range(len(common_models)):
|
for i in range(len(category_list)):
|
||||||
for j in range(len(category_list)):
|
for j in range(len(models)):
|
||||||
value = diff_matrix[i, j]
|
value = diff_matrix[i, j]
|
||||||
ax.text(
|
ax.text(
|
||||||
j,
|
j,
|
||||||
|
|
@ -658,10 +674,10 @@ def create_comparison_plot(
|
||||||
ha="center",
|
ha="center",
|
||||||
va="center",
|
va="center",
|
||||||
color="black" if abs(value) < 50 else "white",
|
color="black" if abs(value) < 50 else "white",
|
||||||
fontsize=8,
|
fontsize=12,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
|
# ax.set_title("Per-Category Performance $\Delta$ (hard − easy)", fontweight="bold")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
@ -702,6 +718,7 @@ def main():
|
||||||
"--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
|
"--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
|
||||||
)
|
)
|
||||||
parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
|
parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
|
||||||
|
parser.add_argument("--compare-model-ids", help="Comma-separated list of model IDs to compare", default=None)
|
||||||
parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
|
parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
|
||||||
parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
|
parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
|
||||||
parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
|
parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
|
||||||
|
|
@ -773,7 +790,8 @@ def main():
|
||||||
if not other_summaries:
|
if not other_summaries:
|
||||||
logger.error("No valid summaries found in comparison directory. Exiting.")
|
logger.error("No valid summaries found in comparison directory. Exiting.")
|
||||||
return 1
|
return 1
|
||||||
fig = create_comparison_plot(summaries, other_summaries, categories)
|
compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
|
||||||
|
fig = create_comparison_plot(summaries, other_summaries, categories, compare_model_ids)
|
||||||
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
|
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
150
notebooks/plot_curriculum.ipynb
Normal file
150
notebooks/plot_curriculum.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue