bug fixes

This commit is contained in:
sam-paech 2025-07-10 07:06:34 +10:00
parent 754cf8e23b
commit 2338e3c6e6
6 changed files with 170 additions and 32 deletions

View file

@ -9,6 +9,7 @@ Control 18 supply centers. Nothing else will do.
- Format: 'F [PROVINCE]/[COAST]' where [COAST] = NC (North), SC (South), EC (East), or WC (West)
- Example: 'F SPA/SC - MAO'
- Only fleets need coast specification.
5. Aim to issue an order for all of your units. Holds tend to be wasted orders.
Your Task:
1. Reason

View file

@ -23,6 +23,7 @@ Version: 2.0 (Hard Mode with complete validation)
import os
import json
import csv
import sys
import argparse
from pathlib import Path
from collections import defaultdict, Counter
@ -42,6 +43,8 @@ except ImportError:
"models.PowerEnum not found via absolute or relative import. "
) from exc
csv.field_size_limit(sys.maxsize)
class StatisticalGameAnalyzer:
"""Production-ready analyzer for AI Diplomacy game statistics.
@ -363,6 +366,8 @@ class StatisticalGameAnalyzer:
# Categorize by relationship
recipient = msg.get('recipient_power')
if recipient is None:
continue
try:
normalized_recipient = PowerEnum(recipient).value
except ValueError:
@ -696,6 +701,8 @@ class StatisticalGameAnalyzer:
# Categorize by relationship
recipient = msg.get('recipient_power')
if recipient is None:
continue
# This will coerce some known aliases to match the 7 acceptable names (see models.py)
normalized_recipient = PowerEnum(recipient)

View file

@ -1,20 +1,3 @@
"""
experiment_runner.analysis.compare_stats
----------------------------------------
Compares two completed Diplomacy experiments, printing every metric
whose confidence interval (at confidence level 1 − α) excludes 0.
Derived "maximum-ever" metrics:
max_supply_centers_owned per-power max across phases
max_territories_controlled per-power max across phases
max_military_units per-power max across phases
max_game_score *game-level* max across powers
(only used in the aggregated-across-powers comparison)
All CLI semantics, CSV outputs, significance tests, etc., remain intact.
"""
from __future__ import annotations
from pathlib import Path
@ -23,6 +6,33 @@ from typing import Dict, List
import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats import multitest as smm
from statsmodels.stats import power as smp
import warnings
# Module-level warning suppression: the statistical helpers below are run on
# many small/degenerate samples (zero variance, single element), which makes
# SciPy emit benign RuntimeWarnings and DegenerateDataWarnings on every call.
try: # present from SciPy ≥ 1.10
    from scipy._lib._util import DegenerateDataWarning
except Exception: # fallback for older SciPy
    # Stand-in class so the filterwarnings call below is always valid even
    # when SciPy does not expose DegenerateDataWarning.
    class DegenerateDataWarning(UserWarning):
        pass
# Divide-by-zero style warnings from zero-variance samples.
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message="invalid value encountered in scalar divide",
)
# skew/kurtosis on near-constant data.
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message="Precision loss occurred in moment calculation",
)
# Bootstrap BCa intervals on degenerate resamples.
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message="The BCa confidence interval cannot be calculated.",
)
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
# ───────────────────────── helpers ──────────────────────────
_EXCLUDE: set[str] = {
@ -108,39 +118,119 @@ def _load_games(exp: Path) -> pd.DataFrame:
return df_game
# ───────────────── Advanced Statistics Helpers ──────────────────
def _bayesian_t_test(a: np.ndarray, b: np.ndarray, alpha: float, n_samples: int = 10000):
"""Perform a simple Bayesian t-test assuming uninformative priors."""
def posterior_samples(data):
n, mean, var = len(data), np.mean(data), np.var(data, ddof=1)
if n == 0 or var == 0: return np.full(n_samples, mean) # Handle edge cases
# Posterior parameters for Normal-Inverse-Gamma
mu_n, nu_n, alpha_n, beta_n = mean, n, n / 2, (n / 2) * var
# Sample from posterior
post_var = stats.invgamma.rvs(a=alpha_n, scale=beta_n, size=n_samples, random_state=0)
post_mean = stats.norm.rvs(loc=mu_n, scale=np.sqrt(post_var / nu_n), size=n_samples, random_state=1)
return post_mean
try:
post_a, post_b = posterior_samples(a), posterior_samples(b)
diff_samples = post_b - post_a
post_mean_diff = np.mean(diff_samples)
ci_low, ci_high = np.percentile(diff_samples, [100 * alpha / 2, 100 * (1 - alpha / 2)])
prob_b_gt_a = np.mean(diff_samples > 0)
return {
"Bayes_Post_Mean_Diff": post_mean_diff,
"Bayes_CI_low": ci_low,
"Bayes_CI_high": ci_high,
"Bayes_Prob_B_gt_A": prob_b_gt_a,
}
except Exception:
return {k: np.nan for k in ["Bayes_Post_Mean_Diff", "Bayes_CI_low", "Bayes_CI_high", "Bayes_Prob_B_gt_A"]}
# ───────────────────── Welch statistics ──────────────────────
def _welch(a: np.ndarray, b: np.ndarray, alpha: float) -> Dict:
# --- Frequentist Welch's t-test ---
_t, p_val = stats.ttest_ind(a, b, equal_var=False)
mean_a, mean_b = a.mean(), b.mean()
diff = mean_b - mean_a
s1, s2 = a.var(ddof=1), b.var(ddof=1)
n1, n2 = len(a), len(b)
se = np.sqrt(s1/n1 + s2/n2)
df = (s1/n1 + s2/n2)**2 / ((s1/n1)**2/(n1-1) + (s2/n2)**2/(n2-1))
ci = stats.t.ppf(1 - alpha/2, df) * se
# --- Standard Deviations and Cohen's d ---
sd_a, sd_b = a.std(ddof=1), b.std(ddof=1)
pooled_sd = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2)
cohen_d = diff / pooled_sd if pooled_sd else np.nan
se = np.sqrt(a.var(ddof=1) / len(a) + b.var(ddof=1) / len(b))
df = len(a) + len(b) - 2
ci = stats.t.ppf(1 - alpha / 2, df=df) * se
# --- Normality/Symmetry Diagnostics ---
skew_a, kurt_a = stats.skew(a), stats.kurtosis(a)
skew_b, kurt_b = stats.skew(b), stats.kurtosis(b)
# --- Non-parametric p-value (Permutation Test) ---
try:
perm_res = stats.permutation_test((a, b), lambda x, y: np.mean(y) - np.mean(x), n_resamples=9999, random_state=0)
p_perm = perm_res.pvalue
except Exception:
p_perm = np.nan
# --- Power for a minimally interesting effect (d=0.5) ---
try:
power = smp.TTestIndPower().solve_power(effect_size=0.5, nobs1=len(a), alpha=alpha, ratio=len(b)/len(a))
except Exception:
power = np.nan
# --- Robust location estimate (Median difference with bootstrap CI) ---
try:
median_diff = np.median(b) - np.median(a)
res = stats.bootstrap((a, b), lambda x, y: np.median(y) - np.median(x),
confidence_level=1-alpha, method='BCa', n_resamples=2499, random_state=0)
median_ci_low, median_ci_high = res.confidence_interval
except Exception:
median_diff, median_ci_low, median_ci_high = np.nan, np.nan, np.nan
# --- Leave-one-out influence summary ---
try:
loo_diffs_a = [np.mean(b) - np.mean(np.delete(a, i)) for i in range(len(a))]
loo_diffs_b = [np.mean(np.delete(b, i)) - np.mean(a) for i in range(len(b))]
all_loo_diffs = loo_diffs_a + loo_diffs_b
loo_diff_min, loo_diff_max = np.min(all_loo_diffs), np.max(all_loo_diffs)
except Exception:
loo_diff_min, loo_diff_max = np.nan, np.nan
# --- Bayesian analysis ---
bayes_results = _bayesian_t_test(a, b, alpha)
return {
"Mean_A": mean_a,
"Mean_B": mean_b,
"Diff": diff,
"CI_low": diff - ci,
"CI_high": diff + ci,
"p_value": p_val,
"Mean_A": mean_a, "Mean_B": mean_b, "Diff": diff,
"SD_A": sd_a, "SD_B": sd_b,
"SE_diff": se,
"CI_low": diff - ci, "CI_high": diff + ci,
"p_value": p_val, "p_perm": p_perm,
"Cohen_d": cohen_d,
"n_A": len(a),
"n_B": len(b),
"n_A": len(a), "n_B": len(b),
"Skew_A": skew_a, "Kurtosis_A": kurt_a,
"Skew_B": skew_b, "Kurtosis_B": kurt_b,
"Power_d_0.5": power,
"Median_Diff": median_diff, "Median_Diff_CI_low": median_ci_low, "Median_Diff_CI_high": median_ci_high,
"LOO_Diff_min": loo_diff_min, "LOO_Diff_max": loo_diff_max,
**bayes_results,
}
# ───────────────── console helpers ───────────────────────────
def _fmt_row(label: str, r: Dict, width: int, ci_label: str) -> str:
ci = f"[{r['CI_low']:+.2f}, {r['CI_high']:+.2f}]"
p_perm_val = r.get('p_perm', np.nan)
return (
f" {label:<{width}} "
f"{r['Diff']:+6.2f} "
f"({r['Mean_A']:.2f}{r['Mean_B']:.2f}) "
f"{ci_label} {ci:<17} "
f"p={r['p_value']:.4g} "
f"(p_perm={p_perm_val:.3f}) "
f"d={r['Cohen_d']:+.2f}"
)
@ -274,6 +364,8 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
rows_agg.append({"Metric": col, **_welch(a_vals, b_vals, alpha)})
agg_df = pd.DataFrame(rows_agg)
if not agg_df.empty:
p_vals = agg_df["p_value"].dropna()
agg_csv = out_dir / f"comparison_aggregated_vs_{tag_b}.csv"
agg_df.to_csv(agg_csv, index=False)
@ -311,6 +403,9 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
)
pow_df = pd.DataFrame(rows_pow)
if not pow_df.empty:
p_vals = pow_df["p_value"].dropna()
pow_csv = out_dir / f"comparison_by_power_vs_{tag_b}.csv"
pow_df.to_csv(pow_csv, index=False)
@ -327,7 +422,7 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
if sub.empty:
continue
n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"])
print(f"{power} (n({tag_a})={n_a}, n({tag_b})={n_b})")
print(f"\n{power} (n({tag_a})={n_a}, n({tag_b})={n_b})")
for _, r in sub.iterrows():
print(_fmt_row(r["Metric"], r, width, ci_label))
@ -346,4 +441,3 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
print(f"\n[warning] phase overlay plot generation failed: {exc}")
print('Complete')

View file

@ -26,6 +26,7 @@ dependencies = [
"sphinx>=8.2.3",
"sphinx-copybutton>=0.5.2",
"sphinx-rtd-theme>=3.0.2",
"statsmodels==0.14.4",
"together>=1.5.17",
"tornado>=5.0",
"tqdm>=4.67.1",

35
uv.lock generated
View file

@ -28,6 +28,7 @@ dependencies = [
{ name = "sphinx" },
{ name = "sphinx-copybutton" },
{ name = "sphinx-rtd-theme" },
{ name = "statsmodels" },
{ name = "together" },
{ name = "tornado" },
{ name = "tqdm" },
@ -57,6 +58,7 @@ requires-dist = [
{ name = "sphinx", specifier = ">=8.2.3" },
{ name = "sphinx-copybutton", specifier = ">=0.5.2" },
{ name = "sphinx-rtd-theme", specifier = ">=3.0.2" },
{ name = "statsmodels", specifier = "==0.14.4" },
{ name = "together", specifier = ">=1.5.17" },
{ name = "tornado", specifier = ">=5.0" },
{ name = "tqdm", specifier = ">=4.67.1" },
@ -1061,6 +1063,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/39/c2/646d2e93e0af70f4e5359d870a63584dacbc324b54d73e6b3267920ff117/pandas-2.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bb3be958022198531eb7ec2008cfc78c5b1eed51af8600c6c5d9160d89d8d249", size = 13231847, upload-time = "2025-06-05T03:27:51.465Z" },
]
[[package]]
name = "patsy"
version = "1.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/81/74f6a65b848ffd16c18f920620ce999fe45fe27f01ab3911260ce4ed85e4/patsy-1.0.1.tar.gz", hash = "sha256:e786a9391eec818c054e359b737bbce692f051aee4c661f4141cc88fb459c0c4", size = 396010, upload-time = "2024-11-12T14:10:54.642Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/87/2b/b50d3d08ea0fc419c183a84210571eba005328efa62b6b98bc28e9ead32a/patsy-1.0.1-py2.py3-none-any.whl", hash = "sha256:751fb38f9e97e62312e921a1954b81e1bb2bcda4f5eeabaf94db251ee791509c", size = 232923, upload-time = "2024-11-12T14:10:52.85Z" },
]
[[package]]
name = "pillow"
version = "11.2.1"
@ -1606,6 +1620,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" },
]
[[package]]
name = "statsmodels"
version = "0.14.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "packaging" },
{ name = "pandas" },
{ name = "patsy" },
{ name = "scipy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/1f/3b/963a015dd8ea17e10c7b0e2f14d7c4daec903baf60a017e756b57953a4bf/statsmodels-0.14.4.tar.gz", hash = "sha256:5d69e0f39060dc72c067f9bb6e8033b6dccdb0bae101d76a7ef0bcc94e898b67", size = 20354802, upload-time = "2024-10-03T16:15:36.273Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/31/f8/2662e6a101315ad336f75168fa9bac71f913ebcb92a6be84031d84a0f21f/statsmodels-0.14.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5a24f5d2c22852d807d2b42daf3a61740820b28d8381daaf59dcb7055bf1a79", size = 10186886, upload-time = "2024-10-03T17:10:44.074Z" },
{ url = "https://files.pythonhosted.org/packages/fa/c0/ee6e8ed35fc1ca9c7538c592f4974547bf72274bc98db1ae4a6e87481a83/statsmodels-0.14.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df4f7864606fa843d7e7c0e6af288f034a2160dba14e6ccc09020a3cf67cb092", size = 9880066, upload-time = "2024-10-03T17:10:56.972Z" },
{ url = "https://files.pythonhosted.org/packages/d1/97/3380ca6d8fd66cfb3d12941e472642f26e781a311c355a4e97aab2ed0216/statsmodels-0.14.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91341cbde9e8bea5fb419a76e09114e221567d03f34ca26e6d67ae2c27d8fe3c", size = 10283521, upload-time = "2024-10-03T17:14:06.216Z" },
{ url = "https://files.pythonhosted.org/packages/fe/2a/55c5b5c5e5124a202ea3fe0bcdbdeceaf91b4ec6164b8434acb9dd97409c/statsmodels-0.14.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1322286a7bfdde2790bf72d29698a1b76c20b8423a55bdcd0d457969d0041f72", size = 10723228, upload-time = "2024-10-03T17:14:19.587Z" },
{ url = "https://files.pythonhosted.org/packages/4f/76/67747e49dc758daae06f33aad8247b718cd7d224f091d2cd552681215bb2/statsmodels-0.14.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e31b95ac603415887c9f0d344cb523889cf779bc52d68e27e2d23c358958fec7", size = 10859503, upload-time = "2024-10-03T17:14:32.798Z" },
{ url = "https://files.pythonhosted.org/packages/1d/eb/cb8b01f5edf8f135eb3d0553d159db113a35b2948d0e51eeb735e7ae09ea/statsmodels-0.14.4-cp313-cp313-win_amd64.whl", hash = "sha256:81030108d27aecc7995cac05aa280cf8c6025f6a6119894eef648997936c2dd0", size = 9817574, upload-time = "2024-10-03T16:14:37.461Z" },
]
[[package]]
name = "tabulate"
version = "0.9.0"