add relationship plots

This commit is contained in:
sam-paech 2025-07-10 10:35:17 +10:00
parent af3fb8ce48
commit 70a876bcee
3 changed files with 216 additions and 7 deletions

View file

@ -28,9 +28,10 @@ import logging
import re
import json
from pathlib import Path
from typing import List
from typing import List, Any
import math
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import pandas as pd
import seaborn as sns
@ -42,6 +43,16 @@ log = logging.getLogger(__name__)
# ───────────────────────── helpers ──────────────────────────
_SEASON_ORDER = {"S": 0, "F": 1, "W": 2, "A": 3}
_POWER_COLOUR = {
"AUSTRIA": "tab:red",
"ENGLAND": "tab:blue",
"FRANCE": "tab:green",
"GERMANY": "tab:purple",
"ITALY": "tab:olive",
"RUSSIA": "tab:brown",
"TURKEY": "tab:orange",
}
def _sanitize(name: str) -> str:
return re.sub(r"[^\w\-\.]", "_", name)
@ -64,6 +75,23 @@ def _numeric_columns(df: pd.DataFrame, extra_exclude: set[str] | None = None) ->
return [c for c in df.select_dtypes("number").columns if c not in exclude]
def _parse_relationships(rel_string: str) -> dict[str, int]:
"""
"AUSTRIA:-1|FRANCE:2" {"AUSTRIA": -1, "FRANCE": 2}
Returns empty dict on blank / nan / bad input.
"""
if not isinstance(rel_string, str) or not rel_string:
return {}
out: dict[str, int] = {}
for part in rel_string.split("|"):
try:
pwr, val = part.split(":")
out[pwr.strip().upper()] = int(val)
except ValueError:
continue
return out
def _phase_sort_key(ph: str) -> tuple[int, int]:
"""
Sort key that keeps normal phases chronological and forces the literal
@ -159,6 +187,165 @@ def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None:
plt.close(fig)
def _plot_relationships_per_game(
all_phase: pd.DataFrame,
root_dir: Path,
gameid_to_rundir: dict[str, str],
) -> None:
"""
For each game, create one PNG per *focal* power that shows
self-perceived relationship to every other power (solid, full colour)
how that other power perceives the focal power (solid, lighter, thinner)
The x-axis is dense and specific to each game (0n-1) with tick labels set
to the actual phase strings, so there are no gaps no matter which phases
appear in different runs.
To keep coincident traces legible, points that would sit on top of one
another are given a minimal vertical jitter. Self points are nudged
down (negative), other points up (positive). Powers keep their canonical
AUSTRIA TURKEY ordering within each direction so that the visual code is
stable across games.
"""
if all_phase.empty or "game_id" not in all_phase.columns:
return
# ── ensure rel_dict column exists ────────────────────────────────────
if "rel_dict" not in all_phase.columns:
all_phase = all_phase.copy()
all_phase["rel_dict"] = all_phase["relationships"].apply(_parse_relationships)
powers = list(_POWER_COLOUR.keys()) # AUSTRIA … TURKEY
power_order = {p: i for i, p in enumerate(powers)}
jitter_step = 0.04 # vertical gap between stacked points
for game_id, game_df in all_phase.groupby("game_id", sort=False):
# dense phase ordering (0 … n-1)
phase_labels = sorted(game_df["game_phase"].unique(), key=_phase_sort_key)
phase_to_x = {ph: idx for idx, ph in enumerate(phase_labels)}
fig_w = max(8, len(phase_labels) * 0.1 + 4)
# quick lookup: (phase, power) → rel_dict
rel_lookup = {
(row.game_phase, row.power_name): row.rel_dict
for row in game_df.itertuples()
}
run_label = gameid_to_rundir.get(str(game_id), f"game_{_sanitize(str(game_id))}")
plot_dir = root_dir / run_label
plot_dir.mkdir(parents=True, exist_ok=True)
for focal in powers:
# ── pre-gather every trace so we can resolve collisions first ─
traces: dict[tuple[str, str], dict[str, Any]] = {}
for other in powers:
if other == focal:
continue
x_vals: list[int] = []
self_vals: list[float] = []
other_vals: list[float] = []
for ph in phase_labels:
x_vals.append(phase_to_x[ph])
# focals view of "other"
self_rels = rel_lookup.get((ph, focal), {})
self_vals.append(self_rels.get(other, float("nan")))
# others view of focal
other_rels = rel_lookup.get((ph, other), {})
other_vals.append(other_rels.get(focal, float("nan")))
traces[(other, "self")] = dict(x=x_vals, y=self_vals)
traces[(other, "other")] = dict(x=x_vals, y=other_vals)
# ── collision-aware jitter: build offset matrix ───────────────
offsets: dict[tuple[str, str], list[float]] = {
key: [0.0] * len(phase_labels) for key in traces
}
for idx in range(len(phase_labels)):
# group all non-NaN points by their integer y level
level_buckets: dict[float, list[tuple[str, str]]] = {}
for key, data in traces.items():
y = data["y"][idx]
if not math.isnan(y):
level_buckets.setdefault(y, []).append(key)
# de-stack each bucket independently
for y_val, bucket in level_buckets.items():
if len(bucket) == 1:
continue # nothing overlaps here
# split into self vs other, then power order
self_keys = sorted(
[k for k in bucket if k[1] == "self"],
key=lambda k: power_order[k[0]],
)
other_keys = sorted(
[k for k in bucket if k[1] == "other"],
key=lambda k: power_order[k[0]],
)
# negative jitter for self
for j, key in enumerate(self_keys, start=1):
offsets[key][idx] = -j * jitter_step
# positive jitter for other
for j, key in enumerate(other_keys):
offsets[key][idx] = j * jitter_step # first "other" gets 0
# ── finally plot using the jittered values ───────────────────
plt.figure(figsize=(fig_w, 5.5))
y_min, y_max = -2, 2 # track extremes for ylim
for other in powers:
if other == focal:
continue
for kind in ("self", "other"):
key = (other, kind)
data = traces[key]
y_off = [
v + off if not math.isnan(v) else v
for v, off in zip(data["y"], offsets[key])
]
# track axis range
for v in y_off:
if not math.isnan(v):
y_min = min(y_min, v)
y_max = max(y_max, v)
base_colour = _POWER_COLOUR[other]
colour = (
base_colour
if kind == "self"
else to_rgba(base_colour, alpha=0.35)
)
plt.plot(
data["x"],
y_off,
label=f"{other} ({kind})",
color=colour,
linewidth=2,
)
plt.xticks(list(phase_to_x.values()), phase_labels, rotation=90, fontsize=8)
margin = 0.1
plt.ylim(y_min - margin, y_max + margin)
plt.ylabel("Relationship value (2 … +2)")
plt.xlabel("Game phase")
plt.title(f"{focal} Relationships {run_label}")
plt.legend(ncol=3, fontsize=8)
plt.tight_layout()
plt.savefig(plot_dir / f"{focal}_relationships.png", dpi=140)
plt.close()
def _plot_phase_level(
@ -274,14 +461,22 @@ def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argu
# 3. plots
sns.set_theme(style="whitegrid")
log.info("Generating aggregated plots")
_plot_game_level(all_game_df, plots_root / "game")
_plot_phase_level(all_phase_df, plots_root / "phase")
_plot_phase_level(all_phase_df, plots_root / "phase")
game_map = _map_game_id_to_run_dir(experiment_dir)
log.info("Generating per-game plots")
_plot_phase_level_per_game(
all_phase_df,
plots_root / "phase_by_game",
game_map,
)
log.info("Generating relationship plots")
_plot_relationships_per_game(
all_phase_df,
plots_root / "relationships",
game_map,
)
log.info("statistical_game_analysis: plots written → %s", plots_root)