fix order classification bug

This commit is contained in:
sam-paech 2025-07-14 22:30:41 +10:00
parent 86563587a5
commit 7735831513
3 changed files with 40 additions and 12 deletions

View file

@ -291,10 +291,15 @@ class GameHistory:
def _order_type(order: str) -> str: def _order_type(order: str) -> str:
o = order.upper() o = order.upper()
if o == "WAIVE":
if o.strip() == "WAIVE":
return "waive" return "waive"
if " H" in o or " HOLD" in o:
# hold: ends with “ H” or “ HOLD”
import re
if re.search(r"\sH\s*$", o) or re.search(r"\sHOLD\s*$", o):
return "hold" return "hold"
if " S " in o: if " S " in o:
return "support" return "support"
if " C " in o: if " C " in o:
@ -303,12 +308,13 @@ class GameHistory:
return "retreat" return "retreat"
if " - " in o: if " - " in o:
return "move" return "move"
if " BUILD" in o or o.endswith(" B") or " B " in o: if re.search(r"\sBUILD\s*$", o) or o.endswith(" B") or " B " in o:
return "build" return "build"
if " DISBAND" in o or o.endswith(" D") or " D " in o: if re.search(r"\sDISBAND\s*$", o) or o.endswith(" D") or " D " in o:
return "disband" return "disband"
return "other" return "other"
# engine fallback # engine fallback
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()} engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
eng2code = { eng2code = {

View file

@ -116,6 +116,15 @@ def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
default=0.05, default=0.05,
help="α for hypothesis tests in comparison mode (default 0.05).", help="α for hypothesis tests in comparison mode (default 0.05).",
) )
p.add_argument(
"--showall",
action="store_true",
help=(
"When used together with --compare_to, prints every metric in the "
"console output, not just significant results (confidence intervals still use --sig_level)."
),
)
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
@ -469,7 +478,13 @@ def main() -> None:
if exp_args.compare_to is not None: if exp_args.compare_to is not None:
from experiment_runner.analysis import compare_stats # local import from experiment_runner.analysis import compare_stats # local import
compare_stats.run(exp_dir, exp_args.compare_to, alpha=exp_args.sig_level) compare_stats.run(
exp_dir,
exp_args.compare_to,
alpha=exp_args.sig_level,
show_all=exp_args.showall,
)
log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir) log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir)
return return

View file

@ -336,7 +336,7 @@ def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None:
# ───────────────────────── public API ───────────────────────── # ───────────────────────── public API ─────────────────────────
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: def run(exp_a: Path, exp_b: Path, alpha: float = 0.05, show_all: bool = False) -> None:
df_a = _load_games(exp_a) df_a = _load_games(exp_a)
df_b = _load_games(exp_b) df_b = _load_games(exp_b)
@ -373,16 +373,18 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] [{tag_a}].") print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] [{tag_a}].")
sig_agg = _significant(agg_df, alpha) sig_agg = _significant(agg_df, alpha)
if sig_agg.empty: rows_to_print = agg_df if show_all else _significant(agg_df, alpha)
if rows_to_print.empty:
_print_hdr(f"Aggregated Across Powers no metric significant at {ci_pct}% CI") _print_hdr(f"Aggregated Across Powers no metric significant at {ci_pct}% CI")
else: else:
n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"]) n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"])
_print_hdr( _print_hdr(
f"Aggregated Across Powers significant at {ci_pct}% " f"Aggregated Across Powers "
f"{'all metrics' if show_all else f'significant at {ci_pct}%'} "
f"(n({tag_a})={n_a}, n({tag_b})={n_b})" f"(n({tag_a})={n_a}, n({tag_b})={n_b})"
) )
width = max(len(m) for m in sig_agg["Metric"]) + 2 width = max(len(m) for m in sig_agg["Metric"]) + 2
for _, r in sig_agg.iterrows(): for _, r in rows_to_print.iterrows():
print(_fmt_row(r["Metric"], r, width, ci_label)) print(_fmt_row(r["Metric"], r, width, ci_label))
# ── section 2: per-power breakdown ──────────────────────── # ── section 2: per-power breakdown ────────────────────────
@ -410,15 +412,20 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
pow_df.to_csv(pow_csv, index=False) pow_df.to_csv(pow_csv, index=False)
sig_pow = _significant(pow_df, alpha) sig_pow = _significant(pow_df, alpha)
if sig_pow.empty:
rows_to_print_pow = pow_df if show_all else _significant(pow_df, alpha)
if rows_to_print_pow.empty:
_print_hdr(f"Per-Power Breakdown no metric significant at {ci_pct}% CI") _print_hdr(f"Per-Power Breakdown no metric significant at {ci_pct}% CI")
else: else:
_print_hdr( _print_hdr(
f"Per-Power Breakdown metrics significant at {ci_pct}% CI (α={alpha})" f"Per-Power Breakdown "
f"{'all metrics' if show_all else f'metrics significant at {ci_pct}% CI'} "
f"(α={alpha})"
) )
width = max(len(m) for m in sig_pow["Metric"]) + 2 width = max(len(m) for m in sig_pow["Metric"]) + 2
for power in powers: for power in powers:
sub = sig_pow[sig_pow["Power"] == power] sub = rows_to_print_pow[rows_to_print_pow["Power"] == power]
if sub.empty: if sub.empty:
continue continue
n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"]) n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])