mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-24 17:05:04 +00:00
savegame fix + chart updates
This commit is contained in:
parent
7edc7c465f
commit
e351aa3841
4 changed files with 211 additions and 41 deletions
|
|
@ -55,28 +55,22 @@ def _numeric_columns(df: pd.DataFrame) -> List[str]:
|
|||
def _load_games(exp: Path) -> pd.DataFrame:
|
||||
"""
|
||||
Return a DataFrame with one row per (game_id, power_name) containing
|
||||
all numeric columns from *_game_analysis.csv plus these derived
|
||||
columns:
|
||||
all numeric columns from *_game_analysis.csv plus these derived columns:
|
||||
|
||||
max_supply_centers_owned
|
||||
max_territories_controlled
|
||||
max_military_units (all per-power maxima across phases)
|
||||
|
||||
max_game_score (max across powers within the game)
|
||||
|
||||
The phase files live under .../analysis/** and are searched
|
||||
recursively so the script works with both “individual” and
|
||||
“combined” layouts.
|
||||
max_supply_centers_owned – per-power max across phases
|
||||
max_territories_controlled – per-power max across phases
|
||||
max_military_units – per-power max across phases
|
||||
max_game_score – game-level max across powers
|
||||
"""
|
||||
root = exp / "analysis"
|
||||
|
||||
# ---------- game-level CSVs ---------------------------------
|
||||
# ----------- game-level CSVs -----------------------------------------
|
||||
game_csvs = list(root.rglob("*_game_analysis.csv"))
|
||||
if not game_csvs:
|
||||
raise FileNotFoundError(f"no *_game_analysis.csv found under {root}")
|
||||
df_game = pd.concat((pd.read_csv(p) for p in game_csvs), ignore_index=True)
|
||||
|
||||
# ---------- derive max_game_score ---------------------------
|
||||
# ----------- derive max_game_score -----------------------------------
|
||||
if "game_score" in df_game.columns:
|
||||
df_game["max_game_score"] = (
|
||||
df_game.groupby("game_id")["game_score"].transform("max")
|
||||
|
|
@ -84,11 +78,10 @@ def _load_games(exp: Path) -> pd.DataFrame:
|
|||
else:
|
||||
df_game["max_game_score"] = np.nan
|
||||
|
||||
# ---------- phase-level maxima for the other three ----------
|
||||
# ----------- per-power maxima from phase files -----------------------
|
||||
phase_csvs = list(root.rglob("*_phase_analysis.csv"))
|
||||
if phase_csvs:
|
||||
df_phase = pd.concat((pd.read_csv(p) for p in phase_csvs), ignore_index=True)
|
||||
|
||||
mapping = {
|
||||
"supply_centers_owned_count": "max_supply_centers_owned",
|
||||
"territories_controlled_count": "max_territories_controlled",
|
||||
|
|
@ -104,14 +97,18 @@ def _load_games(exp: Path) -> pd.DataFrame:
|
|||
)
|
||||
df_game = df_game.merge(max_df, on=["game_id", "power_name"], how="left")
|
||||
|
||||
# ensure all four columns exist
|
||||
# ----------- guarantee all max-columns exist -------------------------
|
||||
for col in _MAX_METRICS:
|
||||
if col not in df_game.columns:
|
||||
df_game[col] = np.nan
|
||||
|
||||
# ----------- critical de-duplication (fixes doubled n) ---------------
|
||||
df_game = df_game.drop_duplicates(subset=["game_id", "power_name"], keep="first")
|
||||
|
||||
return df_game
|
||||
|
||||
|
||||
|
||||
# ───────────────────── Welch statistics ──────────────────────
|
||||
def _welch(a: np.ndarray, b: np.ndarray, alpha: float) -> Dict:
|
||||
_t, p_val = stats.ttest_ind(a, b, equal_var=False)
|
||||
|
|
@ -162,6 +159,92 @@ def _significant(df: pd.DataFrame, alpha: float) -> pd.DataFrame:
|
|||
return df[keep].sort_values("p_value").reset_index(drop=True)
|
||||
|
||||
|
||||
# ---------- phase-level helpers -----------------------------------------
|
||||
def _load_phase(exp: Path) -> pd.DataFrame:
|
||||
root = exp / "analysis"
|
||||
phase_csvs = list(root.rglob("*_phase_analysis.csv"))
|
||||
if not phase_csvs:
|
||||
raise FileNotFoundError(f"no *_phase_analysis.csv found under {root}")
|
||||
return pd.concat((pd.read_csv(p) for p in phase_csvs), ignore_index=True)
|
||||
|
||||
|
||||
def _phase_index(ph_series: pd.Series) -> pd.Series:
|
||||
_SEASON_ORDER = {"S": 0, "F": 1, "W": 2, "A": 3}
|
||||
def _key(ph: str) -> tuple[int, int]:
|
||||
year = int(ph[1:5]) if len(ph) >= 5 and ph[1:5].isdigit() else 0
|
||||
season = _SEASON_ORDER.get(ph[0], 9)
|
||||
return year, season
|
||||
uniq = sorted(ph_series.unique(), key=_key)
|
||||
return ph_series.map({ph: i for i, ph in enumerate(uniq)})
|
||||
|
||||
|
||||
def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None:
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
df_a = _load_phase(exp_a)
|
||||
df_b = _load_phase(exp_b)
|
||||
tag_a, tag_b = exp_a.name or str(exp_a), exp_b.name or str(exp_b)
|
||||
|
||||
df_a["experiment"] = tag_a
|
||||
df_b["experiment"] = tag_b
|
||||
df = pd.concat([df_a, df_b], ignore_index=True)
|
||||
|
||||
if "phase_index" not in df.columns:
|
||||
df["phase_index"] = _phase_index(df["game_phase"])
|
||||
|
||||
num_cols = [c for c in df.select_dtypes("number").columns
|
||||
if c not in _EXCLUDE and c != "phase_index"]
|
||||
|
||||
# aggregate across games: mean per phase × power × experiment
|
||||
agg = (
|
||||
df.groupby(["experiment", "phase_index", "game_phase", "power_name"],
|
||||
as_index=False)[num_cols]
|
||||
.mean()
|
||||
)
|
||||
|
||||
palette = sns.color_palette("tab10", n_colors=len(agg["power_name"].unique()))
|
||||
power_colors = dict(zip(sorted(agg["power_name"].unique()), palette))
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
n_phases = agg["phase_index"].nunique()
|
||||
fig_w = max(8, n_phases * 0.1 + 4)
|
||||
|
||||
for col in num_cols:
|
||||
plt.figure(figsize=(fig_w, 6))
|
||||
for power in sorted(agg["power_name"].unique()):
|
||||
for exp_tag, style in [(tag_a, "--"), (tag_b, "-")]:
|
||||
sub = agg[(agg["power_name"] == power) &
|
||||
(agg["experiment"] == exp_tag)]
|
||||
if sub.empty:
|
||||
continue
|
||||
plt.plot(
|
||||
sub["phase_index"],
|
||||
sub[col],
|
||||
linestyle=style,
|
||||
color=power_colors[power],
|
||||
marker="o",
|
||||
label=f"{power} – {exp_tag}",
|
||||
)
|
||||
|
||||
phases_sorted = (
|
||||
agg.drop_duplicates("phase_index")
|
||||
.sort_values("phase_index")[["phase_index", "game_phase"]]
|
||||
)
|
||||
plt.xticks(
|
||||
phases_sorted["phase_index"],
|
||||
phases_sorted["game_phase"],
|
||||
rotation=90,
|
||||
fontsize=8,
|
||||
)
|
||||
plt.title(col.replace("_", " ").title())
|
||||
plt.xlabel("Game Phase")
|
||||
plt.legend(ncol=2, fontsize=8)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_dir / f"{col}.png", dpi=140)
|
||||
plt.close()
|
||||
|
||||
|
||||
# ───────────────────────── public API ─────────────────────────
|
||||
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
||||
df_a = _load_games(exp_a)
|
||||
|
|
@ -177,7 +260,7 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
|||
tag_a = exp_a.name or str(exp_a)
|
||||
tag_b = exp_b.name or str(exp_b)
|
||||
|
||||
out_dir = exp_a / "analysis" / "comparison"
|
||||
out_dir = exp_b / "analysis" / "comparison"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── section 1: aggregated across powers ───────────────────
|
||||
|
|
@ -252,3 +335,15 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
|
|||
print("\nCSV outputs:")
|
||||
print(f" • {agg_csv}")
|
||||
print(f" • {pow_csv}")
|
||||
|
||||
|
||||
print('\n\nGenerating plots...')
|
||||
# overlay phase-level plots
|
||||
try:
|
||||
_plot_phase_overlay(exp_a, exp_b, out_dir / "phase_overlay")
|
||||
print(f"\nPhase overlay plots → {out_dir / 'phase_overlay'}")
|
||||
except Exception as exc:
|
||||
print(f"\n[warning] phase overlay plot generation failed: {exc}")
|
||||
|
||||
print('Complete')
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
|
|
@ -64,7 +65,16 @@ def _numeric_columns(df: pd.DataFrame, extra_exclude: set[str] | None = None) ->
|
|||
|
||||
|
||||
def _phase_sort_key(ph: str) -> tuple[int, int]:
|
||||
"""Convert 'S1901M' → (1901, 0)."""
|
||||
"""
|
||||
Sort key that keeps normal phases chronological and forces the literal
|
||||
string 'COMPLETED' to the very end.
|
||||
|
||||
• 'S1901M' → (1901, 0)
|
||||
• 'COMPLETED' → (9999, 9)
|
||||
"""
|
||||
if ph.upper() == "COMPLETED":
|
||||
return (9999, 9) # always last
|
||||
|
||||
year = int(ph[1:5]) if len(ph) >= 5 and ph[1:5].isdigit() else 0
|
||||
season = _SEASON_ORDER.get(ph[0], 9)
|
||||
return year, season
|
||||
|
|
@ -75,6 +85,27 @@ def _phase_index(series: pd.Series) -> pd.Series:
|
|||
mapping = {ph: i for i, ph in enumerate(uniq)}
|
||||
return series.map(mapping)
|
||||
|
||||
def _map_game_id_to_run_dir(exp_dir: Path) -> dict[str, str]:
|
||||
"""
|
||||
Reads each runs/run_xxxxx/lmvsgame.json file and returns
|
||||
{game_id_string: 'run_xxxxx'}.
|
||||
"""
|
||||
mapping: dict[str, str] = {}
|
||||
runs_root = exp_dir / "runs"
|
||||
for run_dir in runs_root.glob("run_*"):
|
||||
json_path = run_dir / "lmvsgame.json"
|
||||
if not json_path.exists():
|
||||
continue
|
||||
try:
|
||||
with json_path.open(encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
gid = str(data.get("id", "")) # use top-level "id"
|
||||
if gid:
|
||||
mapping[gid] = run_dir.name
|
||||
except Exception: # corrupt / unreadable → skip
|
||||
continue
|
||||
return mapping
|
||||
|
||||
|
||||
# ───────────────────────── plots ────────────────────────────
|
||||
def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None:
|
||||
|
|
@ -130,7 +161,15 @@ def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None:
|
|||
|
||||
|
||||
|
||||
def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None:
|
||||
def _plot_phase_level(
|
||||
all_phase: pd.DataFrame,
|
||||
plot_dir: Path,
|
||||
title_suffix: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Plots aggregated phase metrics. If *title_suffix* is supplied it is
|
||||
appended to each chart title — handy for per-run plots.
|
||||
"""
|
||||
if all_phase.empty:
|
||||
return
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -147,10 +186,10 @@ def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None:
|
|||
)
|
||||
|
||||
n_phases = agg["phase_index"].nunique()
|
||||
fig_base_width = max(8, n_phases * 0.1 + 4) # 0.45 in per label + padding
|
||||
fig_w = max(8, n_phases * 0.1 + 4)
|
||||
|
||||
for col in num_cols:
|
||||
plt.figure(figsize=(fig_base_width, 6))
|
||||
plt.figure(figsize=(fig_w, 6))
|
||||
sns.lineplot(
|
||||
data=agg,
|
||||
x="phase_index",
|
||||
|
|
@ -169,13 +208,40 @@ def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None:
|
|||
rotation=90,
|
||||
fontsize=8,
|
||||
)
|
||||
title = col.replace("_", " ").title()
|
||||
if title_suffix:
|
||||
title = f"{title} – {title_suffix}"
|
||||
plt.title(title)
|
||||
plt.xlabel("Game Phase")
|
||||
plt.title(col.replace("_", " ").title())
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_dir / f"{_sanitize(col)}.png", dpi=140)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_phase_level_per_game(
|
||||
all_phase: pd.DataFrame,
|
||||
root_dir: Path,
|
||||
gameid_to_rundir: dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
Writes one folder of phase-plots per iteration.
|
||||
Folder name and chart titles use the run directory (e.g. run_00003).
|
||||
"""
|
||||
if all_phase.empty or "game_id" not in all_phase.columns:
|
||||
return
|
||||
|
||||
for game_id, sub in all_phase.groupby("game_id"):
|
||||
run_label = gameid_to_rundir.get(str(game_id), f"game_{_sanitize(str(game_id))}")
|
||||
target = root_dir / run_label
|
||||
|
||||
# ── critical change: drop global phase_index so we rebuild a dense one ──
|
||||
sub = sub.copy().drop(columns=["phase_index"], errors="ignore")
|
||||
|
||||
_plot_phase_level(sub, target, title_suffix=run_label)
|
||||
|
||||
|
||||
|
||||
|
||||
# ───────────────────────── entry-point ─────────────────────────
|
||||
def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argument
|
||||
root = experiment_dir / "analysis" / "statistical_game_analysis"
|
||||
|
|
@ -210,5 +276,12 @@ def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argu
|
|||
sns.set_theme(style="whitegrid")
|
||||
_plot_game_level(all_game_df, plots_root / "game")
|
||||
_plot_phase_level(all_phase_df, plots_root / "phase")
|
||||
game_map = _map_game_id_to_run_dir(experiment_dir)
|
||||
_plot_phase_level_per_game(
|
||||
all_phase_df,
|
||||
plots_root / "phase_by_game",
|
||||
game_map,
|
||||
)
|
||||
|
||||
|
||||
log.info("statistical_game_analysis: plots written → %s", plots_root)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue