fix order classification bug

This commit is contained in:
sam-paech 2025-07-14 22:30:41 +10:00
parent 86563587a5
commit 7735831513
3 changed files with 40 additions and 12 deletions

View file

@ -336,7 +336,7 @@ def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None:
# ───────────────────────── public API ─────────────────────────
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05, show_all: bool = False) -> None:
df_a = _load_games(exp_a)
df_b = _load_games(exp_b)
@ -373,16 +373,18 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] [{tag_a}].")
sig_agg = _significant(agg_df, alpha)
if sig_agg.empty:
rows_to_print = agg_df if show_all else _significant(agg_df, alpha)
if rows_to_print.empty:
_print_hdr(f"Aggregated Across Powers no metric significant at {ci_pct}% CI")
else:
n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"])
_print_hdr(
f"Aggregated Across Powers significant at {ci_pct}% "
f"Aggregated Across Powers "
f"{'all metrics' if show_all else f'significant at {ci_pct}%'} "
f"(n({tag_a})={n_a}, n({tag_b})={n_b})"
)
width = max(len(m) for m in sig_agg["Metric"]) + 2
for _, r in sig_agg.iterrows():
for _, r in rows_to_print.iterrows():
print(_fmt_row(r["Metric"], r, width, ci_label))
# ── section 2: per-power breakdown ────────────────────────
@ -410,15 +412,20 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
pow_df.to_csv(pow_csv, index=False)
sig_pow = _significant(pow_df, alpha)
if sig_pow.empty:
rows_to_print_pow = pow_df if show_all else _significant(pow_df, alpha)
if rows_to_print_pow.empty:
_print_hdr(f"Per-Power Breakdown no metric significant at {ci_pct}% CI")
else:
_print_hdr(
f"Per-Power Breakdown metrics significant at {ci_pct}% CI (α={alpha})"
f"Per-Power Breakdown "
f"{'all metrics' if show_all else f'metrics significant at {ci_pct}% CI'} "
f"(α={alpha})"
)
width = max(len(m) for m in sig_pow["Metric"]) + 2
for power in powers:
sub = sig_pow[sig_pow["Power"] == power]
sub = rows_to_print_pow[rows_to_print_pow["Power"] == power]
if sub.empty:
continue
n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])