diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py index 5ec7277..9f893b2 100644 --- a/ai_diplomacy/game_logic.py +++ b/ai_diplomacy/game_logic.py @@ -6,6 +6,7 @@ import asyncio from typing import Dict, Tuple, Optional, Any from argparse import Namespace from pathlib import Path +import re from diplomacy import Game from diplomacy.utils.export import to_saved_game_format, from_saved_game_format @@ -65,28 +66,17 @@ def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, ov # --- State Management --- -# game_logic.py -_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments +_PHASE_RE = re.compile(r"^[SW](\d{4})[MRA]$") - -def _next_phase_name(short: str) -> str: +def _phase_year(phase_name: str) -> Optional[int]: """ - Return the Diplomacy phase string that chronologically follows *short*. - (E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M) + Return the four-digit year encoded in standard phase strings + like 'S1901M'. For anything non-standard (e.g. 'COMPLETE') + return None so callers can decide how to handle it. """ - season = short[0] # 'S' | 'W' - year = int(short[1:5]) - typ = short[-1] # 'M' | 'R' | 'A' + m = _PHASE_RE.match(phase_name) + return int(m.group(1)) if m else None - idx = _PHASE_ORDER.index(typ) - if idx < 2: # still in the same season - return f"{season}{year}{_PHASE_ORDER[idx + 1]}" - - # typ was 'A' → roll season - if season == "S": # summer → winter, same year - return f"W{year}M" - else: # winter→ spring, next year - return f"S{year + 1}M" def save_game_state( @@ -139,7 +129,8 @@ def save_game_state( current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()} for phase_block in saved_game.get("phases", []): - if int(phase_block["name"][1:5]) > run_config.max_year: + year_val = _phase_year(phase_block["name"]) + if year_val is not None and year_val > run_config.max_year: break phase_name = phase_block["name"] diff --git a/experiment_runner/analysis/compare_stats.py b/experiment_runner/analysis/compare_stats.py index 50ad706..d72bb5f 100644 --- a/experiment_runner/analysis/compare_stats.py +++ b/experiment_runner/analysis/compare_stats.py @@ -55,28 +55,22 @@ def _numeric_columns(df: pd.DataFrame) -> List[str]: def _load_games(exp: Path) -> pd.DataFrame: """ Return a DataFrame with one row per (game_id, power_name) containing - all numeric columns from *_game_analysis.csv plus these derived - columns: + all numeric columns from *_game_analysis.csv plus these derived columns: - max_supply_centers_owned - max_territories_controlled - max_military_units (all per-power maxima across phases) - - max_game_score (max across powers within the game) - - The phase files live under .../analysis/** and are searched - recursively so the script works with both “individual” and - “combined” layouts. + max_supply_centers_owned – per-power max across phases + max_territories_controlled – per-power max across phases + max_military_units – per-power max across phases + max_game_score – game-level max across powers """ root = exp / "analysis" - # ---------- game-level CSVs --------------------------------- + # ----------- game-level CSVs ----------------------------------------- game_csvs = list(root.rglob("*_game_analysis.csv")) if not game_csvs: raise FileNotFoundError(f"no *_game_analysis.csv found under {root}") df_game = pd.concat((pd.read_csv(p) for p in game_csvs), ignore_index=True) - # ---------- derive max_game_score --------------------------- + # ----------- derive max_game_score ----------------------------------- if "game_score" in df_game.columns: df_game["max_game_score"] = ( df_game.groupby("game_id")["game_score"].transform("max") @@ -84,11 +78,10 @@ def _load_games(exp: Path) -> pd.DataFrame: else: df_game["max_game_score"] = np.nan - # ---------- phase-level maxima for the other three ---------- + # ----------- per-power maxima from phase files ----------------------- phase_csvs = list(root.rglob("*_phase_analysis.csv")) if phase_csvs: df_phase = pd.concat((pd.read_csv(p) for p in phase_csvs), ignore_index=True) - mapping = { "supply_centers_owned_count": "max_supply_centers_owned", "territories_controlled_count": "max_territories_controlled", @@ -104,14 +97,18 @@ def _load_games(exp: Path) -> pd.DataFrame: ) df_game = df_game.merge(max_df, on=["game_id", "power_name"], how="left") - # ensure all four columns exist + # ----------- guarantee all max-columns exist ------------------------- for col in _MAX_METRICS: if col not in df_game.columns: df_game[col] = np.nan + # ----------- critical de-duplication (fixes doubled n) --------------- + df_game = df_game.drop_duplicates(subset=["game_id", "power_name"], keep="first") + return df_game + # ───────────────────── Welch statistics ────────────────────── def _welch(a: np.ndarray, b: np.ndarray, alpha: float) -> Dict: _t, p_val = stats.ttest_ind(a, b, equal_var=False) @@ -162,6 +159,92 @@ def _significant(df: pd.DataFrame, alpha: float) -> pd.DataFrame: return df[keep].sort_values("p_value").reset_index(drop=True) +# ---------- phase-level helpers ----------------------------------------- +def _load_phase(exp: Path) -> pd.DataFrame: + root = exp / "analysis" + phase_csvs = list(root.rglob("*_phase_analysis.csv")) + if not phase_csvs: + raise FileNotFoundError(f"no *_phase_analysis.csv found under {root}") + return pd.concat((pd.read_csv(p) for p in phase_csvs), ignore_index=True) + + +def _phase_index(ph_series: pd.Series) -> pd.Series: + _SEASON_ORDER = {"S": 0, "F": 1, "W": 2, "A": 3} + def _key(ph: str) -> tuple[int, int]: + year = int(ph[1:5]) if len(ph) >= 5 and ph[1:5].isdigit() else 0 + season = _SEASON_ORDER.get(ph[0], 9) + return year, season + uniq = sorted(ph_series.unique(), key=_key) + return ph_series.map({ph: i for i, ph in enumerate(uniq)}) + + +def _plot_phase_overlay(exp_a: Path, exp_b: Path, out_dir: Path) -> None: + import seaborn as sns + import matplotlib.pyplot as plt + + df_a = _load_phase(exp_a) + df_b = _load_phase(exp_b) + tag_a, tag_b = exp_a.name or str(exp_a), exp_b.name or str(exp_b) + + df_a["experiment"] = tag_a + df_b["experiment"] = tag_b + df = pd.concat([df_a, df_b], ignore_index=True) + + if "phase_index" not in df.columns: + df["phase_index"] = _phase_index(df["game_phase"]) + + num_cols = [c for c in df.select_dtypes("number").columns + if c not in _EXCLUDE and c != "phase_index"] + + # aggregate across games: mean per phase × power × experiment + agg = ( + df.groupby(["experiment", "phase_index", "game_phase", "power_name"], + as_index=False)[num_cols] + .mean() + ) + + palette = sns.color_palette("tab10", n_colors=len(agg["power_name"].unique())) + power_colors = dict(zip(sorted(agg["power_name"].unique()), palette)) + + out_dir.mkdir(parents=True, exist_ok=True) + n_phases = agg["phase_index"].nunique() + fig_w = max(8, n_phases * 0.1 + 4) + + for col in num_cols: + plt.figure(figsize=(fig_w, 6)) + for power in sorted(agg["power_name"].unique()): + for exp_tag, style in [(tag_a, "--"), (tag_b, "-")]: + sub = agg[(agg["power_name"] == power) & + (agg["experiment"] == exp_tag)] + if sub.empty: + continue + plt.plot( + sub["phase_index"], + sub[col], + linestyle=style, + color=power_colors[power], + marker="o", + label=f"{power} – {exp_tag}", + ) + + phases_sorted = ( + agg.drop_duplicates("phase_index") + .sort_values("phase_index")[["phase_index", "game_phase"]] + ) + plt.xticks( + phases_sorted["phase_index"], + phases_sorted["game_phase"], + rotation=90, + fontsize=8, + ) + plt.title(col.replace("_", " ").title()) + plt.xlabel("Game Phase") + plt.legend(ncol=2, fontsize=8) + plt.tight_layout() + plt.savefig(out_dir / f"{col}.png", dpi=140) + plt.close() + + # ───────────────────────── public API ───────────────────────── def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: df_a = _load_games(exp_a) @@ -177,7 +260,7 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: tag_a = exp_a.name or str(exp_a) tag_b = exp_b.name or str(exp_b) - out_dir = exp_a / "analysis" / "comparison" + out_dir = exp_b / "analysis" / "comparison" out_dir.mkdir(parents=True, exist_ok=True) # ── section 1: aggregated across powers ─────────────────── @@ -252,3 +335,15 @@ def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: print("\nCSV outputs:") print(f" • {agg_csv}") print(f" • {pow_csv}") + + + print('\n\nGenerating plots...') + # overlay phase-level plots + try: + _plot_phase_overlay(exp_a, exp_b, out_dir / "phase_overlay") + print(f"\nPhase overlay plots → {out_dir / 'phase_overlay'}") + except Exception as exc: + print(f"\n[warning] phase overlay plot generation failed: {exc}") + + print('Complete') + diff --git a/experiment_runner/analysis/statistical_game_analysis.py b/experiment_runner/analysis/statistical_game_analysis.py index 35e4709..67f783f 100644 --- a/experiment_runner/analysis/statistical_game_analysis.py +++ b/experiment_runner/analysis/statistical_game_analysis.py @@ -26,6 +26,7 @@ from __future__ import annotations import logging import re +import json from pathlib import Path from typing import List @@ -64,7 +65,16 @@ def _numeric_columns(df: pd.DataFrame, extra_exclude: set[str] | None = None) -> def _phase_sort_key(ph: str) -> tuple[int, int]: - """Convert 'S1901M' → (1901, 0).""" + """ + Sort key that keeps normal phases chronological and forces the literal + string 'COMPLETED' to the very end. + + • 'S1901M' → (1901, 0) + • 'COMPLETED' → (9999, 9) + """ + if ph.upper() == "COMPLETED": + return (9999, 9) # always last + year = int(ph[1:5]) if len(ph) >= 5 and ph[1:5].isdigit() else 0 season = _SEASON_ORDER.get(ph[0], 9) return year, season @@ -75,6 +85,27 @@ def _phase_index(series: pd.Series) -> pd.Series: mapping = {ph: i for i, ph in enumerate(uniq)} return series.map(mapping) +def _map_game_id_to_run_dir(exp_dir: Path) -> dict[str, str]: + """ + Reads each runs/run_xxxxx/lmvsgame.json file and returns + {game_id_string: 'run_xxxxx'}. + """ + mapping: dict[str, str] = {} + runs_root = exp_dir / "runs" + for run_dir in runs_root.glob("run_*"): + json_path = run_dir / "lmvsgame.json" + if not json_path.exists(): + continue + try: + with json_path.open(encoding="utf-8") as fh: + data = json.load(fh) + gid = str(data.get("id", "")) # use top-level "id" + if gid: + mapping[gid] = run_dir.name + except Exception: # corrupt / unreadable → skip + continue + return mapping + # ───────────────────────── plots ──────────────────────────── def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None: @@ -130,7 +161,15 @@ def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None: -def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None: +def _plot_phase_level( + all_phase: pd.DataFrame, + plot_dir: Path, + title_suffix: str = "", +) -> None: + """ + Plots aggregated phase metrics. If *title_suffix* is supplied it is + appended to each chart title — handy for per-run plots. + """ if all_phase.empty: return plot_dir.mkdir(parents=True, exist_ok=True) @@ -147,10 +186,10 @@ def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None: ) n_phases = agg["phase_index"].nunique() - fig_base_width = max(8, n_phases * 0.1 + 4) # 0.45 in per label + padding + fig_w = max(8, n_phases * 0.1 + 4) for col in num_cols: - plt.figure(figsize=(fig_base_width, 6)) + plt.figure(figsize=(fig_w, 6)) sns.lineplot( data=agg, x="phase_index", @@ -169,13 +208,40 @@ def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None: rotation=90, fontsize=8, ) + title = col.replace("_", " ").title() + if title_suffix: + title = f"{title} – {title_suffix}" + plt.title(title) plt.xlabel("Game Phase") - plt.title(col.replace("_", " ").title()) plt.tight_layout() plt.savefig(plot_dir / f"{_sanitize(col)}.png", dpi=140) plt.close() +def _plot_phase_level_per_game( + all_phase: pd.DataFrame, + root_dir: Path, + gameid_to_rundir: dict[str, str], +) -> None: + """ + Writes one folder of phase-plots per iteration. + Folder name and chart titles use the run directory (e.g. run_00003). + """ + if all_phase.empty or "game_id" not in all_phase.columns: + return + + for game_id, sub in all_phase.groupby("game_id"): + run_label = gameid_to_rundir.get(str(game_id), f"game_{_sanitize(str(game_id))}") + target = root_dir / run_label + + # ── critical change: drop global phase_index so we rebuild a dense one ── + sub = sub.copy().drop(columns=["phase_index"], errors="ignore") + + _plot_phase_level(sub, target, title_suffix=run_label) + + + + # ───────────────────────── entry-point ───────────────────────── def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argument root = experiment_dir / "analysis" / "statistical_game_analysis" @@ -210,5 +276,12 @@ def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argu sns.set_theme(style="whitegrid") _plot_game_level(all_game_df, plots_root / "game") _plot_phase_level(all_phase_df, plots_root / "phase") + game_map = _map_game_id_to_run_dir(experiment_dir) + _plot_phase_level_per_game( + all_phase_df, + plots_root / "phase_by_game", + game_map, + ) + log.info("statistical_game_analysis: plots written → %s", plots_root) diff --git a/lm_game.py b/lm_game.py index 12a3012..e7232e5 100644 --- a/lm_game.py +++ b/lm_game.py @@ -54,6 +54,10 @@ def _str2bool(v: str) -> bool: return False raise argparse.ArgumentTypeError(f"Boolean value expected, got '{v}'") +def _detect_victory(game: Game, threshold: int = 18) -> bool: + """True iff any power already owns ≥ `threshold` supply centres.""" + return any(len(p.centers) >= threshold for p in game.powers.values()) + def parse_arguments(): parser = argparse.ArgumentParser( description="Run a Diplomacy game simulation with configurable parameters." @@ -280,6 +284,13 @@ async def main(): game.phase_summaries = {} agents = await initialize_new_game(run_config, game, game_history, llm_log_file_path) + if _detect_victory(game): + game.is_game_done = True # short-circuit the main loop + logger.info( + "Game already complete on load – a power has ≥18 centres " + f"(current phase {game.get_current_phase()})." + ) + # --- 4. Main Game Loop --- while not game.is_game_done: phase_start = time.time()