add simplified prompts

This commit is contained in:
sam-paech 2025-06-27 14:42:05 +10:00
parent 0bd909b30b
commit ebf26cf8a6
33 changed files with 1762 additions and 143 deletions

View file

@ -0,0 +1,214 @@
"""
experiment_runner.analysis.statistical_game_analysis
----------------------------------------------------
Runs the Statistical Game Analyzer to create per-run / combined CSVs,
then produces a suite of PNG plots:
analysis/
statistical_game_analysis/
individual/
run_00000_game_analysis.csv
plots/
game/
final_supply_centers_owned.png
game_summary_heatmap.png
phase/
supply_centers_owned_count.png
Complies with experiment-runners plug-in contract:
run(experiment_dir: pathlib.Path, ctx: dict) -> None
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# third-party analyser that creates the CSVs
from analysis.statistical_game_analysis import StatisticalGameAnalyzer # type: ignore
log = logging.getLogger(__name__)
# ───────────────────────── helpers ──────────────────────────
_SEASON_ORDER = {"S": 0, "F": 1, "W": 2, "A": 3}
def _sanitize(name: str) -> str:
return re.sub(r"[^\w\-\.]", "_", name)
def _discover_csvs(individual_dir: Path, pattern: str) -> List[Path]:
return sorted(individual_dir.glob(pattern))
def _numeric_columns(df: pd.DataFrame, extra_exclude: set[str] | None = None) -> List[str]:
exclude = {
"game_id",
"llm_model",
"power_name",
"game_phase",
"analyzed_response_type",
}
if extra_exclude:
exclude |= extra_exclude
return [c for c in df.select_dtypes("number").columns if c not in exclude]
def _phase_sort_key(ph: str) -> tuple[int, int]:
"""Convert 'S1901M' → (1901, 0)."""
year = int(ph[1:5]) if len(ph) >= 5 and ph[1:5].isdigit() else 0
season = _SEASON_ORDER.get(ph[0], 9)
return year, season
def _phase_index(series: pd.Series) -> pd.Series:
uniq = sorted(series.unique(), key=_phase_sort_key)
mapping = {ph: i for i, ph in enumerate(uniq)}
return series.map(mapping)
# ───────────────────────── plots ────────────────────────────
def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None:
"""
Box-plots per metric (hue = power, legend removed).
Z-score heat-map: powers × metrics, colour-coded by relative standing.
"""
plot_dir.mkdir(parents=True, exist_ok=True)
num_cols = _numeric_columns(all_games)
# ── per-metric box-plots ──────────────────────────────────────────
for col in num_cols:
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(
data=all_games,
x="power_name",
y=col,
hue="power_name",
palette="pastel",
dodge=False,
ax=ax,
)
leg = ax.get_legend()
if leg is not None:
leg.remove()
ax.set_title(col.replace("_", " ").title())
fig.tight_layout()
fig.savefig(plot_dir / f"{_sanitize(col)}.png", dpi=140)
plt.close(fig)
# ── summary heat-map (column-wise z-scores) ───────────────────────
# 1) mean across runs 2) z-score each column
summary = all_games.groupby("power_name")[num_cols].mean().sort_index()
zscores = summary.apply(lambda col: (col - col.mean()) / col.std(ddof=0), axis=0)
fig_w = max(6, len(num_cols) * 0.45 + 2)
fig, ax = plt.subplots(figsize=(fig_w, 6))
sns.heatmap(
zscores,
cmap="coolwarm",
center=0,
linewidths=0.4,
annot=True,
fmt=".2f",
ax=ax,
)
ax.set_title("Relative Standing (column-wise z-score)")
ax.set_ylabel("Power")
fig.tight_layout()
fig.savefig(plot_dir.parent / "game_summary_zscore_heatmap.png", dpi=160)
plt.close(fig)
def _plot_phase_level(all_phase: pd.DataFrame, plot_dir: Path) -> None:
if all_phase.empty:
return
plot_dir.mkdir(parents=True, exist_ok=True)
if "phase_index" not in all_phase.columns:
all_phase["phase_index"] = _phase_index(all_phase["game_phase"])
num_cols = _numeric_columns(all_phase)
agg = (
all_phase
.groupby(["phase_index", "game_phase", "power_name"], as_index=False)[num_cols]
.mean()
)
n_phases = agg["phase_index"].nunique()
fig_base_width = max(8, n_phases * 0.1 + 4) # 0.45 in per label + padding
for col in num_cols:
plt.figure(figsize=(fig_base_width, 6))
sns.lineplot(
data=agg,
x="phase_index",
y=col,
hue="power_name",
marker="o",
)
phases_sorted = (
agg.drop_duplicates("phase_index")
.sort_values("phase_index")[["phase_index", "game_phase"]]
)
plt.xticks(
phases_sorted["phase_index"],
phases_sorted["game_phase"],
rotation=90,
fontsize=8,
)
plt.xlabel("Game Phase")
plt.title(col.replace("_", " ").title())
plt.tight_layout()
plt.savefig(plot_dir / f"{_sanitize(col)}.png", dpi=140)
plt.close()
# ───────────────────────── entry-point ─────────────────────────
def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argument
root = experiment_dir / "analysis" / "statistical_game_analysis"
indiv_dir = root / "individual"
plots_root = root / "plots"
# 1. (re)generate CSVs
try:
StatisticalGameAnalyzer().analyze_multiple_folders(
str(experiment_dir / "runs"), str(root)
)
log.info("statistical_game_analysis: CSV generation complete")
except Exception as exc: # noqa: broad-except
log.exception("statistical_game_analysis: CSV generation failed %s", exc)
return
# 2. load CSVs
game_csvs = _discover_csvs(indiv_dir, "*_game_analysis.csv")
phase_csvs = _discover_csvs(indiv_dir, "*_phase_analysis.csv")
if not game_csvs:
log.warning("statistical_game_analysis: no *_game_analysis.csv found")
return
all_game_df = pd.concat((pd.read_csv(p) for p in game_csvs), ignore_index=True)
all_phase_df = (
pd.concat((pd.read_csv(p) for p in phase_csvs), ignore_index=True)
if phase_csvs else pd.DataFrame()
)
# 3. plots
sns.set_theme(style="whitegrid")
_plot_game_level(all_game_df, plots_root / "game")
_plot_phase_level(all_phase_df, plots_root / "phase")
log.info("statistical_game_analysis: plots written → %s", plots_root)