Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2026-05-03 17:53:26 +00:00.
Commit: "comparison plot" (#436).
This commit is contained in:
parent
0cda6b1205
commit
5961a10145
1 changed files with 65 additions and 65 deletions
|
|
@ -587,76 +587,86 @@ def create_dashboard(summaries: Dict[str, Dict[str, Any]], categories: Dict[str,
|
||||||
def create_comparison_plot(
    summaries: Dict[str, Dict[str, Any]],
    other_summaries: Dict[str, Dict[str, Any]],
    categories: Optional[Dict[str, List[str]]] = None,
) -> Figure:
    """Build a heat-map of per-category score differences (scaled to -100 … 100).

    Rows : model IDs present in both `summaries` and `other_summaries`
    Cols : category names (keys of `categories`)
    Value: 100 * (mean(score in summaries) - mean(score in other_summaries))

    A numeric annotation (rounded to 2 dp) is rendered in every cell.

    Args:
        summaries: Dictionary of model summaries; each value must contain a
            "dataset_best_scores" mapping of dataset name -> best score.
        other_summaries: Dictionary of model summaries used as the baseline.
        categories: Dictionary mapping category names to lists of dataset
            names. If None, a single "all" category covering every dataset of
            the first summary is used.

    Returns:
        Matplotlib figure (empty figure if either input is empty or no model
        ID appears in both result sets).
    """
    if not summaries or not other_summaries:
        logger.error("No summaries provided for comparison")
        return plt.figure()

    if categories is None:
        # Fall back to one bucket containing every dataset of the first model.
        all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
        categories = {"all": list(all_ds)}

    # models appearing in both result sets
    common_models = [m for m in summaries if m in other_summaries]
    if not common_models:
        logger.error("No overlapping model IDs between the two result sets.")
        return plt.figure()

    # sort models by overall performance (best first) so the heat-map rows
    # are ordered by mean score in `summaries`
    overall_scores = {}
    for model_name, summary in summaries.items():
        scores = list(summary["dataset_best_scores"].values())
        overall_scores[model_name] = np.mean(scores)
    models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
    common_models = [m for m in models if m in common_models]

    category_list = sorted(categories.keys())
    diff_matrix = np.zeros((len(common_models), len(category_list)))

    # compute 100 × Δ; datasets missing from a summary count as 0.0
    for i, model in enumerate(common_models):
        cur_scores = summaries[model]["dataset_best_scores"]
        base_scores = other_summaries[model]["dataset_best_scores"]
        for j, cat in enumerate(category_list):
            ds = categories[cat]
            cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
            base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
            diff_matrix[i, j] = 100 * (cur_mean - base_mean)  # scale to -100 … 100

    # ---------------------------------------------------------------- Plot
    # Figure grows with the matrix so labels stay readable.
    fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
    im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)

    # colour-bar
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")

    # ticks / labels
    ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
    ax.set_yticks(np.arange(len(common_models)), labels=common_models)

    # minor ticks at cell boundaries give a white grid for readability
    ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
    ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)

    # annotate each cell; white text on strongly coloured cells for contrast
    for i in range(len(common_models)):
        for j in range(len(category_list)):
            value = diff_matrix[i, j]
            ax.text(
                j,
                i,
                f"{value:.2f}",
                ha="center",
                va="center",
                color="black" if abs(value) < 50 else "white",
                fontsize=8,
            )

    ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
    plt.tight_layout()
    return fig
|
@ -768,18 +778,8 @@ def main():
|
||||||
if not other_summaries:
|
if not other_summaries:
|
||||||
logger.error("No valid summaries found in comparison directory. Exiting.")
|
logger.error("No valid summaries found in comparison directory. Exiting.")
|
||||||
return 1
|
return 1
|
||||||
|
fig = create_comparison_plot(summaries, other_summaries, categories)
|
||||||
comparison_output_dir = args.output_dir / "comparison"
|
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
|
||||||
if not comparison_output_dir.exists():
|
|
||||||
logger.info(f"Creating comparison output directory {comparison_output_dir}")
|
|
||||||
comparison_output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
for model_name in summaries.keys():
|
|
||||||
if model_name not in other_summaries:
|
|
||||||
logger.warning(f"Model {model_name} not found in comparison directory. Skippping...")
|
|
||||||
continue
|
|
||||||
fig = create_comparison_plot(summaries, other_summaries, model_name, categories)
|
|
||||||
save_figure(fig, comparison_output_dir, model_name, args.format, args.dpi)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown plot type: {plot_type}")
|
logger.warning(f"Unknown plot type: {plot_type}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue