From 77358315132d327cea277354d42303bebb9088ac Mon Sep 17 00:00:00 2001 From: sam-paech <152407511+sam-paech@users.noreply.github.com> Date: Mon, 14 Jul 2025 22:30:41 +1000 Subject: [PATCH] fix order classification bug --- ai_diplomacy/game_history.py | 14 ++++++++++---- experiment_runner.py | 17 ++++++++++++++++- experiment_runner/analysis/compare_stats.py | 21 ++++++++++++++------- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py index ef51c14..fdcd947 100644 --- a/ai_diplomacy/game_history.py +++ b/ai_diplomacy/game_history.py @@ -291,10 +291,15 @@ class GameHistory: def _order_type(order: str) -> str: o = order.upper() - if o == "WAIVE": + + if o.strip() == "WAIVE": return "waive" - if " H" in o or " HOLD" in o: + + # hold: ends with “ H” or “ HOLD” + import re + if re.search(r"\sH\s*$", o) or re.search(r"\sHOLD\s*$", o): return "hold" + if " S " in o: return "support" if " C " in o: @@ -303,12 +308,13 @@ class GameHistory: return "retreat" if " - " in o: return "move" - if " BUILD" in o or o.endswith(" B") or " B " in o: + if re.search(r"\sBUILD\s*$", o) or o.endswith(" B") or " B " in o: return "build" - if " DISBAND" in o or o.endswith(" D") or " D " in o: + if re.search(r"\sDISBAND\s*$", o) or o.endswith(" D") or " D " in o: return "disband" return "other" + # engine fallback engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()} eng2code = { diff --git a/experiment_runner.py b/experiment_runner.py index 4786aae..62aa880 100644 --- a/experiment_runner.py +++ b/experiment_runner.py @@ -116,6 +116,15 @@ def _add_experiment_flags(p: argparse.ArgumentParser) -> None: default=0.05, help="α for hypothesis tests in comparison mode (default 0.05).", ) + p.add_argument( + "--showall", + action="store_true", + help=( + "When used together with --compare_to, prints every metric in the " + "console output, not just significant results (confidence intervals still use --sig_level)." + ), + ) + def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: @@ -469,7 +478,13 @@ def main() -> None: if exp_args.compare_to is not None: from experiment_runner.analysis import compare_stats # local import - compare_stats.run(exp_dir, exp_args.compare_to, alpha=exp_args.sig_level) + compare_stats.run( + exp_dir, + exp_args.compare_to, + alpha=exp_args.sig_level, + show_all=exp_args.showall, + ) + log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir) return diff --git a/experiment_runner/analysis/compare_stats.py b/experiment_runner/analysis/compare_stats.py index b636c39..f8851f7 100644 --- a/experiment_runner/analysis/compare_stats.py +++ b/experiment_runner/analysis/compare_stats.py @@ -336,7 +336,7 @@ def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None: # ───────────────────────── public API ───────────────────────── -def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: +def run(exp_a: Path, exp_b: Path, alpha: float = 0.05, show_all: bool = False) -> None: df_a = _load_games(exp_a) df_b = _load_games(exp_b) @@ -373,16 +373,18 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] – [{tag_a}].") sig_agg = _significant(agg_df, alpha) - if sig_agg.empty: + rows_to_print = agg_df if show_all else _significant(agg_df, alpha) + if rows_to_print.empty: _print_hdr(f"Aggregated Across Powers – no metric significant at {ci_pct}% CI") else: n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"]) _print_hdr( - f"Aggregated Across Powers – significant at {ci_pct}% " + f"Aggregated Across Powers – " + f"{'all metrics' if show_all else f'significant at {ci_pct}%'} " f"(n({tag_a})={n_a}, n({tag_b})={n_b})" ) width = max(len(m) for m in sig_agg["Metric"]) + 2 - for _, r in sig_agg.iterrows(): + for _, r in rows_to_print.iterrows(): print(_fmt_row(r["Metric"], r, width, ci_label)) # ── section 2: per-power breakdown ──────────────────────── @@ -410,15 +412,20 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: pow_df.to_csv(pow_csv, index=False) sig_pow = _significant(pow_df, alpha) - if sig_pow.empty: + + rows_to_print_pow = pow_df if show_all else _significant(pow_df, alpha) + if rows_to_print_pow.empty: _print_hdr(f"Per-Power Breakdown – no metric significant at {ci_pct}% CI") else: _print_hdr( - f"Per-Power Breakdown – metrics significant at {ci_pct}% CI (α={alpha})" + f"Per-Power Breakdown – " + f"{'all metrics' if show_all else f'metrics significant at {ci_pct}% CI'} " + f"(α={alpha})" + ) width = max(len(m) for m in sig_pow["Metric"]) + 2 for power in powers: - sub = sig_pow[sig_pow["Power"] == power] + sub = rows_to_print_pow[rows_to_print_pow["Power"] == power] if sub.empty: continue n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])