mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
fix order classification bug
This commit is contained in:
parent
86563587a5
commit
7735831513
3 changed files with 40 additions and 12 deletions
|
|
@ -291,10 +291,15 @@ class GameHistory:
|
||||||
|
|
||||||
def _order_type(order: str) -> str:
|
def _order_type(order: str) -> str:
|
||||||
o = order.upper()
|
o = order.upper()
|
||||||
if o == "WAIVE":
|
|
||||||
|
if o.strip() == "WAIVE":
|
||||||
return "waive"
|
return "waive"
|
||||||
if " H" in o or " HOLD" in o:
|
|
||||||
|
# hold: ends with “ H” or “ HOLD”
|
||||||
|
import re
|
||||||
|
if re.search(r"\sH\s*$", o) or re.search(r"\sHOLD\s*$", o):
|
||||||
return "hold"
|
return "hold"
|
||||||
|
|
||||||
if " S " in o:
|
if " S " in o:
|
||||||
return "support"
|
return "support"
|
||||||
if " C " in o:
|
if " C " in o:
|
||||||
|
|
@ -303,12 +308,13 @@ class GameHistory:
|
||||||
return "retreat"
|
return "retreat"
|
||||||
if " - " in o:
|
if " - " in o:
|
||||||
return "move"
|
return "move"
|
||||||
if " BUILD" in o or o.endswith(" B") or " B " in o:
|
if re.search(r"\sBUILD\s*$", o) or o.endswith(" B") or " B " in o:
|
||||||
return "build"
|
return "build"
|
||||||
if " DISBAND" in o or o.endswith(" D") or " D " in o:
|
if re.search(r"\sDISBAND\s*$", o) or o.endswith(" D") or " D " in o:
|
||||||
return "disband"
|
return "disband"
|
||||||
return "other"
|
return "other"
|
||||||
|
|
||||||
|
|
||||||
# engine fallback
|
# engine fallback
|
||||||
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
||||||
eng2code = {
|
eng2code = {
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,15 @@ def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
|
||||||
default=0.05,
|
default=0.05,
|
||||||
help="α for hypothesis tests in comparison mode (default 0.05).",
|
help="α for hypothesis tests in comparison mode (default 0.05).",
|
||||||
)
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--showall",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"When used together with --compare_to, prints every metric in the "
|
||||||
|
"console output, not just significant results (confidence intervals still use --sig_level)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
||||||
|
|
@ -469,7 +478,13 @@ def main() -> None:
|
||||||
if exp_args.compare_to is not None:
|
if exp_args.compare_to is not None:
|
||||||
from experiment_runner.analysis import compare_stats # local import
|
from experiment_runner.analysis import compare_stats # local import
|
||||||
|
|
||||||
compare_stats.run(exp_dir, exp_args.compare_to, alpha=exp_args.sig_level)
|
compare_stats.run(
|
||||||
|
exp_dir,
|
||||||
|
exp_args.compare_to,
|
||||||
|
alpha=exp_args.sig_level,
|
||||||
|
show_all=exp_args.showall,
|
||||||
|
)
|
||||||
|
|
||||||
log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir)
|
log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -336,7 +336,7 @@ def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None:
|
||||||
|
|
||||||
|
|
||||||
# ───────────────────────── public API ─────────────────────────
|
# ───────────────────────── public API ─────────────────────────
|
||||||
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05, show_all: bool = False) -> None:
|
||||||
df_a = _load_games(exp_a)
|
df_a = _load_games(exp_a)
|
||||||
df_b = _load_games(exp_b)
|
df_b = _load_games(exp_b)
|
||||||
|
|
||||||
|
|
@ -373,16 +373,18 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
||||||
print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] – [{tag_a}].")
|
print(f"Comparing {tag_a} to {tag_b}: All comparisons are [{tag_b}] – [{tag_a}].")
|
||||||
|
|
||||||
sig_agg = _significant(agg_df, alpha)
|
sig_agg = _significant(agg_df, alpha)
|
||||||
if sig_agg.empty:
|
rows_to_print = agg_df if show_all else _significant(agg_df, alpha)
|
||||||
|
if rows_to_print.empty:
|
||||||
_print_hdr(f"Aggregated Across Powers – no metric significant at {ci_pct}% CI")
|
_print_hdr(f"Aggregated Across Powers – no metric significant at {ci_pct}% CI")
|
||||||
else:
|
else:
|
||||||
n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"])
|
n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"])
|
||||||
_print_hdr(
|
_print_hdr(
|
||||||
f"Aggregated Across Powers – significant at {ci_pct}% "
|
f"Aggregated Across Powers – "
|
||||||
|
f"{'all metrics' if show_all else f'significant at {ci_pct}%'} "
|
||||||
f"(n({tag_a})={n_a}, n({tag_b})={n_b})"
|
f"(n({tag_a})={n_a}, n({tag_b})={n_b})"
|
||||||
)
|
)
|
||||||
width = max(len(m) for m in sig_agg["Metric"]) + 2
|
width = max(len(m) for m in sig_agg["Metric"]) + 2
|
||||||
for _, r in sig_agg.iterrows():
|
for _, r in rows_to_print.iterrows():
|
||||||
print(_fmt_row(r["Metric"], r, width, ci_label))
|
print(_fmt_row(r["Metric"], r, width, ci_label))
|
||||||
|
|
||||||
# ── section 2: per-power breakdown ────────────────────────
|
# ── section 2: per-power breakdown ────────────────────────
|
||||||
|
|
@ -410,15 +412,20 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
||||||
pow_df.to_csv(pow_csv, index=False)
|
pow_df.to_csv(pow_csv, index=False)
|
||||||
|
|
||||||
sig_pow = _significant(pow_df, alpha)
|
sig_pow = _significant(pow_df, alpha)
|
||||||
if sig_pow.empty:
|
|
||||||
|
rows_to_print_pow = pow_df if show_all else _significant(pow_df, alpha)
|
||||||
|
if rows_to_print_pow.empty:
|
||||||
_print_hdr(f"Per-Power Breakdown – no metric significant at {ci_pct}% CI")
|
_print_hdr(f"Per-Power Breakdown – no metric significant at {ci_pct}% CI")
|
||||||
else:
|
else:
|
||||||
_print_hdr(
|
_print_hdr(
|
||||||
f"Per-Power Breakdown – metrics significant at {ci_pct}% CI (α={alpha})"
|
f"Per-Power Breakdown – "
|
||||||
|
f"{'all metrics' if show_all else f'metrics significant at {ci_pct}% CI'} "
|
||||||
|
f"(α={alpha})"
|
||||||
|
|
||||||
)
|
)
|
||||||
width = max(len(m) for m in sig_pow["Metric"]) + 2
|
width = max(len(m) for m in sig_pow["Metric"]) + 2
|
||||||
for power in powers:
|
for power in powers:
|
||||||
sub = sig_pow[sig_pow["Power"] == power]
|
sub = rows_to_print_pow[rows_to_print_pow["Power"] == power]
|
||||||
if sub.empty:
|
if sub.empty:
|
||||||
continue
|
continue
|
||||||
n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])
|
n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue