mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
add relationship plots
This commit is contained in:
parent
af3fb8ce48
commit
70a876bcee
3 changed files with 216 additions and 7 deletions
|
|
@ -28,9 +28,10 @@ import logging
|
|||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from typing import List, Any
|
||||
import math
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.colors import to_rgba
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
|
|
@ -42,6 +43,16 @@ log = logging.getLogger(__name__)
|
|||
# ───────────────────────── helpers ──────────────────────────
|
||||
_SEASON_ORDER = {"S": 0, "F": 1, "W": 2, "A": 3}
|
||||
|
||||
_POWER_COLOUR = {
|
||||
"AUSTRIA": "tab:red",
|
||||
"ENGLAND": "tab:blue",
|
||||
"FRANCE": "tab:green",
|
||||
"GERMANY": "tab:purple",
|
||||
"ITALY": "tab:olive",
|
||||
"RUSSIA": "tab:brown",
|
||||
"TURKEY": "tab:orange",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize(name: str) -> str:
|
||||
return re.sub(r"[^\w\-\.]", "_", name)
|
||||
|
|
@ -64,6 +75,23 @@ def _numeric_columns(df: pd.DataFrame, extra_exclude: set[str] | None = None) ->
|
|||
return [c for c in df.select_dtypes("number").columns if c not in exclude]
|
||||
|
||||
|
||||
def _parse_relationships(rel_string: str) -> dict[str, int]:
|
||||
"""
|
||||
"AUSTRIA:-1|FRANCE:2" → {"AUSTRIA": -1, "FRANCE": 2}
|
||||
Returns empty dict on blank / nan / bad input.
|
||||
"""
|
||||
if not isinstance(rel_string, str) or not rel_string:
|
||||
return {}
|
||||
out: dict[str, int] = {}
|
||||
for part in rel_string.split("|"):
|
||||
try:
|
||||
pwr, val = part.split(":")
|
||||
out[pwr.strip().upper()] = int(val)
|
||||
except ValueError:
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def _phase_sort_key(ph: str) -> tuple[int, int]:
|
||||
"""
|
||||
Sort key that keeps normal phases chronological and forces the literal
|
||||
|
|
@ -159,6 +187,165 @@ def _plot_game_level(all_games: pd.DataFrame, plot_dir: Path) -> None:
|
|||
plt.close(fig)
|
||||
|
||||
|
||||
def _plot_relationships_per_game(
|
||||
all_phase: pd.DataFrame,
|
||||
root_dir: Path,
|
||||
gameid_to_rundir: dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
For each game, create one PNG per *focal* power that shows
|
||||
• self-perceived relationship to every other power (solid, full colour)
|
||||
• how that other power perceives the focal power (solid, lighter, thinner)
|
||||
|
||||
The x-axis is dense and specific to each game (0…n-1) with tick labels set
|
||||
to the actual phase strings, so there are no gaps no matter which phases
|
||||
appear in different runs.
|
||||
|
||||
To keep coincident traces legible, points that would sit on top of one
|
||||
another are given a minimal vertical jitter. “Self” points are nudged
|
||||
down (negative), “other” points up (positive). Powers keep their canonical
|
||||
AUSTRIA → TURKEY ordering within each direction so that the visual code is
|
||||
stable across games.
|
||||
"""
|
||||
if all_phase.empty or "game_id" not in all_phase.columns:
|
||||
return
|
||||
|
||||
# ── ensure rel_dict column exists ────────────────────────────────────
|
||||
if "rel_dict" not in all_phase.columns:
|
||||
all_phase = all_phase.copy()
|
||||
all_phase["rel_dict"] = all_phase["relationships"].apply(_parse_relationships)
|
||||
|
||||
powers = list(_POWER_COLOUR.keys()) # AUSTRIA … TURKEY
|
||||
power_order = {p: i for i, p in enumerate(powers)}
|
||||
jitter_step = 0.04 # vertical gap between stacked points
|
||||
|
||||
for game_id, game_df in all_phase.groupby("game_id", sort=False):
|
||||
# dense phase ordering (0 … n-1)
|
||||
phase_labels = sorted(game_df["game_phase"].unique(), key=_phase_sort_key)
|
||||
phase_to_x = {ph: idx for idx, ph in enumerate(phase_labels)}
|
||||
fig_w = max(8, len(phase_labels) * 0.1 + 4)
|
||||
|
||||
# quick lookup: (phase, power) → rel_dict
|
||||
rel_lookup = {
|
||||
(row.game_phase, row.power_name): row.rel_dict
|
||||
for row in game_df.itertuples()
|
||||
}
|
||||
|
||||
run_label = gameid_to_rundir.get(str(game_id), f"game_{_sanitize(str(game_id))}")
|
||||
plot_dir = root_dir / run_label
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for focal in powers:
|
||||
# ── pre-gather every trace so we can resolve collisions first ─
|
||||
traces: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
for other in powers:
|
||||
if other == focal:
|
||||
continue
|
||||
|
||||
x_vals: list[int] = []
|
||||
self_vals: list[float] = []
|
||||
other_vals: list[float] = []
|
||||
|
||||
for ph in phase_labels:
|
||||
x_vals.append(phase_to_x[ph])
|
||||
|
||||
# focal’s view of "other"
|
||||
self_rels = rel_lookup.get((ph, focal), {})
|
||||
self_vals.append(self_rels.get(other, float("nan")))
|
||||
|
||||
# other’s view of focal
|
||||
other_rels = rel_lookup.get((ph, other), {})
|
||||
other_vals.append(other_rels.get(focal, float("nan")))
|
||||
|
||||
traces[(other, "self")] = dict(x=x_vals, y=self_vals)
|
||||
traces[(other, "other")] = dict(x=x_vals, y=other_vals)
|
||||
|
||||
# ── collision-aware jitter: build offset matrix ───────────────
|
||||
offsets: dict[tuple[str, str], list[float]] = {
|
||||
key: [0.0] * len(phase_labels) for key in traces
|
||||
}
|
||||
|
||||
for idx in range(len(phase_labels)):
|
||||
# group all non-NaN points by their integer y level
|
||||
level_buckets: dict[float, list[tuple[str, str]]] = {}
|
||||
for key, data in traces.items():
|
||||
y = data["y"][idx]
|
||||
if not math.isnan(y):
|
||||
level_buckets.setdefault(y, []).append(key)
|
||||
|
||||
# de-stack each bucket independently
|
||||
for y_val, bucket in level_buckets.items():
|
||||
if len(bucket) == 1:
|
||||
continue # nothing overlaps here
|
||||
|
||||
# split into self vs other, then power order
|
||||
self_keys = sorted(
|
||||
[k for k in bucket if k[1] == "self"],
|
||||
key=lambda k: power_order[k[0]],
|
||||
)
|
||||
other_keys = sorted(
|
||||
[k for k in bucket if k[1] == "other"],
|
||||
key=lambda k: power_order[k[0]],
|
||||
)
|
||||
|
||||
# negative jitter for self
|
||||
for j, key in enumerate(self_keys, start=1):
|
||||
offsets[key][idx] = -j * jitter_step
|
||||
# positive jitter for other
|
||||
for j, key in enumerate(other_keys):
|
||||
offsets[key][idx] = j * jitter_step # first "other" gets 0
|
||||
|
||||
# ── finally plot using the jittered values ───────────────────
|
||||
plt.figure(figsize=(fig_w, 5.5))
|
||||
y_min, y_max = -2, 2 # track extremes for ylim
|
||||
|
||||
for other in powers:
|
||||
if other == focal:
|
||||
continue
|
||||
|
||||
for kind in ("self", "other"):
|
||||
key = (other, kind)
|
||||
data = traces[key]
|
||||
y_off = [
|
||||
v + off if not math.isnan(v) else v
|
||||
for v, off in zip(data["y"], offsets[key])
|
||||
]
|
||||
|
||||
# track axis range
|
||||
for v in y_off:
|
||||
if not math.isnan(v):
|
||||
y_min = min(y_min, v)
|
||||
y_max = max(y_max, v)
|
||||
|
||||
base_colour = _POWER_COLOUR[other]
|
||||
colour = (
|
||||
base_colour
|
||||
if kind == "self"
|
||||
else to_rgba(base_colour, alpha=0.35)
|
||||
)
|
||||
|
||||
plt.plot(
|
||||
data["x"],
|
||||
y_off,
|
||||
label=f"{other} ({kind})",
|
||||
color=colour,
|
||||
linewidth=2,
|
||||
)
|
||||
|
||||
plt.xticks(list(phase_to_x.values()), phase_labels, rotation=90, fontsize=8)
|
||||
margin = 0.1
|
||||
plt.ylim(y_min - margin, y_max + margin)
|
||||
plt.ylabel("Relationship value (−2 … +2)")
|
||||
plt.xlabel("Game phase")
|
||||
plt.title(f"{focal} Relationships – {run_label}")
|
||||
plt.legend(ncol=3, fontsize=8)
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig(plot_dir / f"{focal}_relationships.png", dpi=140)
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _plot_phase_level(
|
||||
|
|
@ -274,14 +461,22 @@ def run(experiment_dir: Path, ctx: dict) -> None: # pylint: disable=unused-argu
|
|||
|
||||
# 3. plots
|
||||
sns.set_theme(style="whitegrid")
|
||||
log.info("Generating aggregated plots")
|
||||
_plot_game_level(all_game_df, plots_root / "game")
|
||||
_plot_phase_level(all_phase_df, plots_root / "phase")
|
||||
_plot_phase_level(all_phase_df, plots_root / "phase")
|
||||
game_map = _map_game_id_to_run_dir(experiment_dir)
|
||||
log.info("Generating per-game plots")
|
||||
_plot_phase_level_per_game(
|
||||
all_phase_df,
|
||||
plots_root / "phase_by_game",
|
||||
game_map,
|
||||
)
|
||||
log.info("Generating relationship plots")
|
||||
_plot_relationships_per_game(
|
||||
all_phase_df,
|
||||
plots_root / "relationships",
|
||||
game_map,
|
||||
)
|
||||
|
||||
|
||||
log.info("statistical_game_analysis: plots written → %s", plots_root)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue