diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index bff7849..74554ba 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -484,8 +484,10 @@ class BaseModelClient: unanswered_messages += f"\nFrom {msg['sender']} in {msg['phase']}: {msg['content']}\n" else: unanswered_messages += "\nNo urgent messages requiring direct responses.\n" - - return context + unanswered_messages + "\n\n" + instructions + + final_prompt = context + unanswered_messages + "\n\n" + instructions + final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey') + return final_prompt async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py self, diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py index 72cdcc7..27abd23 100644 --- a/ai_diplomacy/game_history.py +++ b/ai_diplomacy/game_history.py @@ -3,6 +3,7 @@ import logging from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, List, Optional +from pprint import pformat logger = logging.getLogger("utils") logger.setLevel(logging.INFO) @@ -22,6 +23,7 @@ class Phase: name: str # e.g. "SPRING 1901" plans: Dict[str, str] = field(default_factory=dict) messages: List[Message] = field(default_factory=list) + submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict) orders_by_power: Dict[str, List[str]] = field( default_factory=lambda: defaultdict(list) ) @@ -77,7 +79,7 @@ class Phase: # Join multiple results with commas result_str = f" ({', '.join(results[i])})" else: - result_str = " (successful)" + result_str = " (success)" result += f" {order}{result_str}\n" result += "\n" return result @@ -151,7 +153,124 @@ class GameHistory: return {} return self.phases[-1].plans - # NEW METHOD + + + + + def get_order_history_for_prompt( + self, + game: "Game", + power_name: str, + current_phase_name: str, + num_movement_phases_to_show: int = 1, + ) -> str: + # ── guard clauses ────────────────────────────────────────── + if not self.phases or num_movement_phases_to_show <= 0: + return "\n(No order history to show)\n" + + prev = [p for p in self.phases if p.name != current_phase_name] + if not prev: + return "\n(No previous phases in history)\n" + + start, seen = 0, 0 + for i in range(len(prev) - 1, -1, -1): + if prev[i].name.endswith("M"): + seen += 1 + if seen >= num_movement_phases_to_show: + start = i + break + phases_to_report = prev[start:] + if not phases_to_report: + return "\n(No relevant order history in look-back window)\n" + + # ── helpers ─────────────────────────────────────────────── + def _scalar(res): + """Flatten lists/dicts to a single outcome token (string).""" + tag = res + while isinstance(tag, list): + tag = tag[0] if tag else "" + if isinstance(tag, dict): + tag = tag.get("outcome") or tag.get("result") or "" + return str(tag).strip().lower() + + engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()} + eng2code = { + "AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", + "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR" + } + norm = game.map.norm + + out_lines = ["**ORDER HISTORY (Recent Rounds)**"] + + for ph in phases_to_report: + if not (ph.orders_by_power or ph.submitted_orders_by_power): + continue + out_lines.append(f"\n--- Orders from Phase: {ph.name} ---") + + for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)): + submitted = ph.submitted_orders_by_power.get(pwr, []) + accepted = ph.orders_by_power.get(pwr, []) + + if isinstance(submitted, str): + submitted = [submitted] + if isinstance(accepted, str): + accepted = [accepted] + + def _norm_keep(o): # keep WAIVE readable + return o if o.upper() == "WAIVE" else norm(o) + + sub_norm = {_norm_keep(o): o for o in submitted} + acc_norm = {_norm_keep(o): o for o in accepted} + if not submitted and not accepted: + continue + + out_lines.append(f" {pwr}:") + + # outcome source + raw_res = ph.results_by_power.get(pwr) or ph.results_by_power or {} + if not raw_res: + eng = engine_phases.get(ph.name) + if eng and hasattr(eng, "order_results"): + key = next((k for k, v in eng2code.items() if v == pwr), None) + raw_res = (eng.order_results or {}).get(key, {}) + + seen_ok = set() + + # 1️⃣ accepted orders + for idx, order in enumerate(accepted): + if isinstance(raw_res, dict): + res_raw = raw_res.get(order) or raw_res.get(" ".join(order.split()[:2])) + elif isinstance(raw_res, list) and idx < len(raw_res): + res_raw = raw_res[idx] + else: + res_raw = "" + + tag = _scalar(res_raw) + if not tag or tag == "ok": + tag = "success" + elif "bounce" in tag: + tag = "bounce" + elif "void" == tag: + tag = "void: no effect" + + out_lines.append(f" {order} ({tag})") + seen_ok.add(_norm_keep(order)) + + # 2️⃣ invalid submissions + for k in sorted(set(sub_norm) - seen_ok): + out_lines.append(f" {sub_norm[k]} (Rejected by engine: invalid)") + + if len(out_lines) == 1: + return "\n(No orders were issued in recent history)\n" + return "\n".join(out_lines) + + + + + + + + def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str: current_phase: Optional[Phase] = None for phase_obj in self.phases: @@ -288,80 +407,4 @@ class GameHistory: 'content': msg_content }) - return ignored_by_power - - # MODIFIED METHOD (renamed from get_game_history) - def get_previous_phases_history( - self, power_name: str, current_phase_name: str, include_plans: bool = True, num_prev_phases: int = 5 - ) -> str: - if not self.phases: - return "\n(No game history available)\n" - - relevant_phases = [p for p in self.phases if p.name != current_phase_name] - - if not relevant_phases: - return "\n(No previous game history before this round)\n" - - phases_to_report = relevant_phases[-num_prev_phases:] - - if not phases_to_report: - return "\n(No previous game history available within the lookback window)\n" - - game_history_str = "" - - for phase_idx, phase in enumerate(phases_to_report): - phase_content_str = f"\nPHASE: {phase.name}\n" - current_phase_has_content = False - - global_msgs = phase.get_global_messages() - if global_msgs: - phase_content_str += "\n GLOBAL MESSAGES:\n" - phase_content_str += "".join([f" {line}\n" for line in global_msgs.strip().split('\n')]) - current_phase_has_content = True - - private_msgs = phase.get_private_messages(power_name) - if private_msgs: - phase_content_str += "\n PRIVATE MESSAGES:\n" - for other_power, messages in private_msgs.items(): - phase_content_str += f" Conversation with {other_power}:\n" - phase_content_str += "".join([f" {line}\n" for line in messages.strip().split('\n')]) - current_phase_has_content = True - - if phase.orders_by_power: - phase_content_str += "\n ORDERS:\n" - for power, orders in phase.orders_by_power.items(): - indicator = " (your power)" if power == power_name else "" - phase_content_str += f" {power}{indicator}:\n" - results = phase.results_by_power.get(power, []) - for i, order in enumerate(orders): - result_str = " (successful)" - if i < len(results) and results[i] and not all(r == "" for r in results[i]): - result_str = f" ({', '.join(results[i])})" - phase_content_str += f" {order}{result_str}\n" - phase_content_str += "\n" - current_phase_has_content = True - - if current_phase_has_content: - if not game_history_str: - game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n" - game_history_str += phase_content_str - if phase_idx < len(phases_to_report) -1 : - game_history_str += " " + "-" * 48 + "\n" - - if include_plans and phases_to_report: - last_reported_previous_phase = phases_to_report[-1] - if last_reported_previous_phase.plans: - if not game_history_str: - game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n" - game_history_str += f"\n PLANS SUBMITTED FOR PHASE {last_reported_previous_phase.name}:\n" - if power_name in last_reported_previous_phase.plans: - game_history_str += f" Your Plan: {last_reported_previous_phase.plans[power_name]}\n" - for p_other, plan_other in last_reported_previous_phase.plans.items(): - if p_other != power_name: - game_history_str += f" {p_other}'s Plan: {plan_other}\n" - game_history_str += "\n" - - if not game_history_str.replace("**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n", "").strip(): - return "\n(No relevant previous game history to display)\n" - - return game_history_str.strip() + return ignored_by_power \ No newline at end of file diff --git a/ai_diplomacy/prompt_constructor.py b/ai_diplomacy/prompt_constructor.py index b21947f..a89847f 100644 --- a/ai_diplomacy/prompt_constructor.py +++ b/ai_diplomacy/prompt_constructor.py @@ -124,9 +124,20 @@ def build_context_prompt( # Build {home_centers} home_centers_str = ", ".join(HOME_CENTERS.get(power_name.upper(), [])) + order_history_str = game_history.get_order_history_for_prompt( + game=game, # Pass the game object for normalization + power_name=power_name, + current_phase_name=year_phase, + num_movement_phases_to_show=1 + ) + # Replace token only if it exists (template may not include it) if "{home_centers}" in context_template: context_template = context_template.replace("{home_centers}", home_centers_str) + + # Following the pattern for home_centers, use replace for safety + if "{order_history}" in context_template: + context_template = context_template.replace("{order_history}", order_history_str) context = context_template.format( power_name=power_name, @@ -202,5 +213,9 @@ def construct_order_generation_prompt( ) final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions + + # Make the power names more LLM friendly + final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey') + print(final_prompt) - return final_prompt + return final_prompt \ No newline at end of file diff --git a/ai_diplomacy/prompts/context_prompt.txt b/ai_diplomacy/prompts/context_prompt.txt index 6ce74f3..3a34f44 100644 --- a/ai_diplomacy/prompts/context_prompt.txt +++ b/ai_diplomacy/prompts/context_prompt.txt @@ -13,6 +13,9 @@ Relationships: {agent_relationships} RECENT PRIVATE DIARY ENTRIES (Your inner thoughts and plans): {agent_private_diary} +ORDER HISTORY +{order_history} + GAME MAP Unit Locations: {all_unit_locations} diff --git a/ai_diplomacy/prompts_simple/context_prompt.txt b/ai_diplomacy/prompts_simple/context_prompt.txt index 592d67c..aa7612a 100644 --- a/ai_diplomacy/prompts_simple/context_prompt.txt +++ b/ai_diplomacy/prompts_simple/context_prompt.txt @@ -14,6 +14,9 @@ Relationships: {agent_relationships} # Recent Private Diary Entries (Your inner thoughts and plans): {agent_private_diary} +# Order History +{order_history} + # Game Map Unit Locations: {all_unit_locations} diff --git a/ai_diplomacy/prompts_simple/negotiation_diary_prompt.txt b/ai_diplomacy/prompts_simple/negotiation_diary_prompt.txt index 4d7a3fd..ba40b6c 100644 --- a/ai_diplomacy/prompts_simple/negotiation_diary_prompt.txt +++ b/ai_diplomacy/prompts_simple/negotiation_diary_prompt.txt @@ -19,9 +19,9 @@ Game State: TASK Analyze the negotiations, goals, relationships, and game state to: 1. Summarize key outcomes and agreements -2. State your strategic intent for {current_phase} +2. State your specific intents for {current_phase}, including moves you have agreed to in negotiations and whether you intend to fulfil them. 3. Update relationships as needed (Enemy, Unfriendly, Neutral, Friendly, Ally) -4. Note which powers are not responding to your messages and consider adjusting your approach +4. Important: You will not see the full negotiation log in the order decision phase, so you must transmit key information about the negotiations to your future self via this summary. When powers ignore your messages, consider: - They may be intentionally avoiding commitment diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index 016f427..2eb55c4 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -102,110 +102,100 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]: async def get_valid_orders( game: Game, - client, # This is the BaseModelClient instance + client, # BaseModelClient instance board_state, power_name: str, possible_orders: Dict[str, List[str]], - game_history, # This is GameHistory instance - model_error_stats: Dict[str, Dict[str, int]], - agent_goals: Optional[List[str]] = None, - agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added new parameter + game_history, + model_error_stats, + agent_goals=None, + agent_relationships=None, + agent_private_diary_str=None, log_file_path: str = None, phase: str = None, -) -> List[str]: +) -> Dict[str, List[str]]: """ - Tries up to 'max_retries' to generate and validate orders. - If invalid, we append the error feedback to the conversation - context for the next retry. If still invalid, return fallback. + Generates orders with the LLM, validates them by round-tripping through the + engine, and returns **both** the accepted and rejected orders so the caller + can record invalid attempts. + + Returns + ------- + dict : { "valid": [...], "invalid": [...] } """ - # Ask the LLM for orders - orders = await client.get_orders( + # ── 1. Ask the model ─────────────────────────────────────── + raw_orders = await client.get_orders( game=game, board_state=board_state, power_name=power_name, possible_orders=possible_orders, - conversation_text=game_history, # Pass GameHistory instance + conversation_text=game_history, model_error_stats=model_error_stats, agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary_str=agent_private_diary_str, # Pass the diary string + agent_private_diary_str=agent_private_diary_str, log_file_path=log_file_path, phase=phase, ) - - # Initialize list to track invalid order information - invalid_info = [] - - # Validate each order - all_valid = True - valid_orders = [] - - if not isinstance(orders, list): # Ensure orders is a list before iterating - logger.warning(f"[{power_name}] Orders received from LLM is not a list: {orders}. Using fallback.") - model_error_stats[client.model_name]["order_decoding_errors"] += 1 # Use client.model_name - return client.fallback_orders(possible_orders) - for move in orders: - # Skip empty orders - if not move or move.strip() == "": - continue - - # Handle special case for WAIVE - if move.upper() == "WAIVE": - valid_orders.append(move) - continue - - # Example move: "A PAR H" -> unit="A PAR", order_part="H" - tokens = move.split(" ", 2) - if len(tokens) < 3: - invalid_info.append(f"Order '{move}' is malformed; expected 'A PAR H' style.") - all_valid = False - continue - - unit = " ".join(tokens[:2]) # e.g. "A PAR" - order_part = tokens[2] # e.g. "H" or "S A MAR" + invalid_info: list[str] = [] + valid: list[str] = [] + invalid: list[str] = [] - # Use the internal game validation method - if order_part == "B": # Build orders - validity = 1 # hack because game._valid_order doesn't support 'B' - elif order_part == "D": # Disband orders - # Check if the unit is actually one of the power's units - if unit in game.powers[power_name].units: - validity = 1 # Simple check, engine handles full validation - else: - validity = 0 - else: # Movement, Support, Hold, Convoy, Retreat - try: - validity = game._valid_order( - game.powers[power_name], unit, order_part, report=1 - ) - except Exception as e: - logger.warning(f"Error validating order '{move}': {e}") - invalid_info.append(f"Order '{move}' caused an error: {e}") - validity = 0 - all_valid = False - - if validity == 1: - valid_orders.append(move) - else: - invalid_info.append(f"Order '{move}' is invalid for {power_name}") - all_valid = False - - # Log validation results - if invalid_info: - logger.debug(f"[{power_name}] Invalid orders: {', '.join(invalid_info)}") - - if all_valid and valid_orders: - logger.debug(f"[{power_name}] All orders valid: {valid_orders}") - return valid_orders - else: - logger.debug(f"[{power_name}] Some orders invalid, using fallback.") - # Use client.model_name for stats key, as power_name might not be unique if multiple agents use same model + # ── 2. Type check ────────────────────────────────────────── + if not isinstance(raw_orders, list): + logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", + power_name, raw_orders) model_error_stats[client.model_name]["order_decoding_errors"] += 1 + return {"valid": client.fallback_orders(possible_orders), "invalid": []} + + # ── 3. Round-trip validation with engine ─────────────────── + CODE_TO_ENGINE = { + "AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE", + "GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY", + } + engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name] + + for move in raw_orders: + if not move or not move.strip(): + continue + + upper = move.upper() + + # WAIVE is always valid + if upper == "WAIVE": + valid.append("WAIVE") + continue + + game.clear_orders(engine_power) + game.set_orders(engine_power, [upper]) + normed = game.get_orders(engine_power) + + if normed: # accepted + valid.append(normed[0]) + else: # rejected + invalid.append(upper) + invalid_info.append(f"Order '{move}' is invalid for {power_name}") + + game.clear_orders(engine_power) # clean slate for main engine flow + + # ── 4. Legacy logging & stats updates ────────────────────── + if invalid_info: # at least one bad move + logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info)) + model_error_stats[client.model_name]["order_decoding_errors"] += 1 + logger.debug("[%s] Some orders invalid, using fallback.", power_name) + else: + logger.debug("[%s] All orders valid: %s", power_name, valid) + + # ── 5. Fallback when nothing survives ───────────────────── + if not valid: fallback = client.fallback_orders(possible_orders) - return fallback + return {"valid": fallback, "invalid": invalid} + + return {"valid": valid, "invalid": invalid} + + def normalize_and_compare_orders( diff --git a/experiment_runner.py b/experiment_runner.py index 0095a05..92ed29f 100644 --- a/experiment_runner.py +++ b/experiment_runner.py @@ -99,6 +99,22 @@ def _add_experiment_flags(p: argparse.ArgumentParser) -> None: "flag to lm_game for deterministic behaviour)." ), ) + p.add_argument( + "--compare_to", + type=Path, + default=None, + help=( + "Path to another completed experiment directory. " + "If supplied alongside --experiment_dir, the runner skips game " + "execution and produces a statistical comparison between the two." + ), + ) + p.add_argument( + "--sig_level", + type=float, + default=0.05, + help="α for hypothesis tests in comparison mode (default 0.05).", + ) def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: @@ -169,6 +185,17 @@ def _add_lm_game_flags(p: argparse.ArgumentParser) -> None: "Defaults to the packaged prompts directory." ), ) + p.add_argument( + "--simple_prompts", + type=_str2bool, + nargs="?", + const=True, + default=False, + help=( + "When true (1 / true / yes) the engine switches to simpler prompts " + "which low-midrange models handle better." + ), + ) # ──────────────────────────────────────────────────────────────────────────── @@ -255,6 +282,20 @@ _RunInfo = collections.namedtuple( "_RunInfo", "index run_dir seed cmd_line returncode elapsed_s" ) +def _str2bool(v: str | bool) -> bool: + """ + Accepts typical textual truthy / falsy values and returns a bool. + Mirrors the helper used inside lm_game. + """ + if isinstance(v, bool): + return v + val = v.lower() + if val in ("yes", "y", "true", "t", "1"): + return True + if val in ("no", "n", "false", "f", "0"): + return False + raise argparse.ArgumentTypeError("boolean value expected") + def _mk_run_dir(exp_dir: Path, idx: int) -> Path: run_dir = exp_dir / "runs" / f"run_{idx:05d}" @@ -372,6 +413,13 @@ def main() -> None: log.info("Appending to existing experiment: %s", exp_dir) exp_dir.mkdir(parents=True, exist_ok=True) + if exp_args.compare_to is not None: + from experiment_runner.analysis import compare_stats # local import + + compare_stats.run(exp_dir, exp_args.compare_to, alpha=exp_args.sig_level) + log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir) + return + # Persist experiment-level config cfg_path = exp_dir / "config.json" if not cfg_path.exists(): # ← new guard diff --git a/experiment_runner/analysis/compare_stats.py b/experiment_runner/analysis/compare_stats.py new file mode 100644 index 0000000..c89ea49 --- /dev/null +++ b/experiment_runner/analysis/compare_stats.py @@ -0,0 +1,168 @@ +""" +experiment_runner.analysis.compare_stats +---------------------------------------- + +Compares two completed Diplomacy experiments. Console output now +shows *all* metrics whose 95 % CI excludes 0 (α = 0.05 by default). + +CSV files remain: + + /analysis/comparison/ + comparison_aggregated_vs_.csv + comparison_by_power_vs_.csv +""" +from __future__ import annotations + +from pathlib import Path +from typing import Dict, List + +import numpy as np +import pandas as pd +from scipy import stats + +# ───────────────────────── helpers ────────────────────────── +_EXCLUDE: set[str] = { + "game_id", + "llm_model", + "power_name", + "game_phase", + "analyzed_response_type", +} + + +def _numeric_columns(df: pd.DataFrame) -> List[str]: + return [c for c in df.select_dtypes("number").columns if c not in _EXCLUDE] + + +def _load_games(exp: Path) -> pd.DataFrame: + indiv = exp / "analysis" / "statistical_game_analysis" / "individual" + csvs = list(indiv.glob("*_game_analysis.csv")) + if not csvs: + raise FileNotFoundError(f"no *_game_analysis.csv under {indiv}") + return pd.concat((pd.read_csv(p) for p in csvs), ignore_index=True) + + +def _welch(a: np.ndarray, b: np.ndarray, alpha: float) -> Dict: + _t, p_val = stats.ttest_ind(a, b, equal_var=False) + mean_a, mean_b = a.mean(), b.mean() + diff = mean_b - mean_a + pooled_sd = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2) + cohen_d = diff / pooled_sd if pooled_sd else np.nan + se = np.sqrt(a.var(ddof=1) / len(a) + b.var(ddof=1) / len(b)) + df = len(a) + len(b) - 2 + ci = stats.t.ppf(1 - alpha / 2, df=df) * se + return { + "Mean_A": mean_a, + "Mean_B": mean_b, + "Diff": diff, + "CI_low": diff - ci, + "CI_high": diff + ci, + "p_value": p_val, + "Cohen_d": cohen_d, + "n_A": len(a), + "n_B": len(b), + } + + +# ───────────────────────── console formatting ───────────────────────── +def _fmt_row(label: str, r: Dict, width: int) -> str: + ci = f"[{r['CI_low']:+.2f}, {r['CI_high']:+.2f}]" + return ( + f" {label:<{width}} " + f"{r['Diff']:+6.2f} " + f"({r['Mean_A']:.2f} → {r['Mean_B']:.2f}) " + f"95%CI {ci:<17} " + f"p={r['p_value']:.4g} " + f"d={r['Cohen_d']:+.2f}" + ) + + +def _print_hdr(title: str) -> None: + print(f"\n{title}") + print("─" * len(title)) + + +def _significant(df: pd.DataFrame, alpha: float) -> pd.DataFrame: + """Return rows whose CI excludes 0 (equivalently p < alpha).""" + sig = df[ + ((df["CI_low"] > 0) & (df["CI_high"] > 0)) + | ((df["CI_low"] < 0) & (df["CI_high"] < 0)) + | (df["p_value"] < alpha) # fallback, same criterion + ].copy() + return sig.sort_values("p_value").reset_index(drop=True) + + +# ───────────────────────── public API ───────────────────────── +def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None: + df_a = _load_games(exp_a) + df_b = _load_games(exp_b) + + metrics = sorted(set(_numeric_columns(df_a)) & set(_numeric_columns(df_b))) + if not metrics: + print("No overlapping numeric metrics to compare.") + return + + out_dir = exp_a / "analysis" / "comparison" + out_dir.mkdir(parents=True, exist_ok=True) + + # ── section 1: aggregated across powers ─────────────────────────── + rows_agg: List[Dict] = [] + for col in metrics: + a_vals = df_a.groupby("game_id")[col].mean().dropna().to_numpy() + b_vals = df_b.groupby("game_id")[col].mean().dropna().to_numpy() + if len(a_vals) < 2 or len(b_vals) < 2: + continue + rows_agg.append({"Metric": col, **_welch(a_vals, b_vals, alpha)}) + + agg_df = pd.DataFrame(rows_agg) + agg_csv = out_dir / f"comparison_aggregated_vs_{exp_b.name}.csv" + agg_df.to_csv(agg_csv, index=False) + + sig_agg = _significant(agg_df, alpha) + if not sig_agg.empty: + n_a, n_b = int(sig_agg.iloc[0]["n_A"]), int(sig_agg.iloc[0]["n_B"]) + _print_hdr(f"Aggregated Across Powers – significant at 95 % CI (nA={n_a}, nB={n_b})") + label_w = max(len(m) for m in sig_agg["Metric"]) + 2 + for _, r in sig_agg.iterrows(): + print(_fmt_row(r["Metric"], r, label_w)) + else: + _print_hdr("Aggregated Across Powers – no metric significant at 95 % CI") + + # ── section 2: per-power breakdown ─────────────────────────────── + rows_pow: List[Dict] = [] + powers = sorted(set(df_a["power_name"]) & set(df_b["power_name"])) + for power in powers: + sub_a = df_a[df_a["power_name"] == power] + sub_b = df_b[df_b["power_name"] == power] + for col in metrics: + a_vals = sub_a[col].dropna().to_numpy() + b_vals = sub_b[col].dropna().to_numpy() + if len(a_vals) < 2 or len(b_vals) < 2: + continue + rows_pow.append( + {"Power": power, "Metric": col, **_welch(a_vals, b_vals, alpha)} + ) + + pow_df = pd.DataFrame(rows_pow) + pow_csv = out_dir / f"comparison_by_power_vs_{exp_b.name}.csv" + pow_df.to_csv(pow_csv, index=False) + + sig_pow = _significant(pow_df, alpha) + if not sig_pow.empty: + _print_hdr(f"Per-Power Breakdown – metrics significant at 95 % CI (α={alpha})") + label_w = max(len(m) for m in sig_pow["Metric"]) + 2 + for power in powers: + sub = sig_pow[sig_pow["Power"] == power] + if sub.empty: + continue + n_a, n_b = int(sub.iloc[0]["n_A"]), int(sub.iloc[0]["n_B"]) + print(f"{power} (nA={n_a}, nB={n_b})") + for _, r in sub.iterrows(): + print(_fmt_row(r["Metric"], r, label_w)) + else: + _print_hdr("Per-Power Breakdown – no metric significant at 95 % CI") + + # ── summary of file outputs ─────────────────────────────────────── + print("\nCSV outputs:") + print(f" • {agg_csv}") + print(f" • {pow_csv}") diff --git a/lm_game.py b/lm_game.py index c2fb61d..bd41a24 100644 --- a/lm_game.py +++ b/lm_game.py @@ -313,18 +313,30 @@ async def main(): active_powers = [p for p, a in agents.items() if not game.powers[p].is_eliminated()] order_power_names = [p for p in active_powers if gather_possible_orders(game, p)] + submitted_orders_this_phase = defaultdict(list) for i, result in enumerate(order_results): p_name = order_power_names[i] - if isinstance(result, Exception): - logger.error(f"Error getting orders for {p_name}: {result}", exc_info=result) - game.set_orders(p_name, []) - else: - orders = result - game.set_orders(p_name, orders) - if orders: - await agents[p_name].generate_order_diary_entry(game, orders, llm_log_file_path) + if isinstance(result, Exception): + logger.error("Error getting orders for %s: %s", p_name, result, exc_info=result) + valid, invalid = [], [] + else: + valid = result.get("valid", []) + invalid = result.get("invalid", []) + + # what the engine will actually execute + game.set_orders(p_name, valid) + + # what we record for prompt/history purposes + submitted_orders_this_phase[p_name] = valid + invalid + + # optional: diary entry only for the orders we tried to submit + if valid or invalid: + await agents[p_name].generate_order_diary_entry( + game, valid + invalid, llm_log_file_path + ) + # --- 4d. Process Phase --- completed_phase = current_phase game.process() @@ -333,6 +345,25 @@ async def main(): logger.info(f"{power_name}: {power.centers}") # --- 4e. Post-Processing and State Updates --- + phase_history_from_game = game.get_phase_history() + if phase_history_from_game: + last_phase_from_game = phase_history_from_game[-1] + if last_phase_from_game.name == completed_phase: + phase_obj_in_my_history = game_history._get_phase(completed_phase) + if phase_obj_in_my_history: + # Store the orders the agents generated + phase_obj_in_my_history.submitted_orders_by_power = submitted_orders_this_phase + # Store the orders the engine actually accepted + phase_obj_in_my_history.orders_by_power = last_phase_from_game.orders + + # Store the results for the accepted orders + converted_results = defaultdict(list) + if last_phase_from_game.results: + for pwr, res_list in last_phase_from_game.results.items(): + converted_results[pwr] = [[res] for res in res_list] + phase_obj_in_my_history.results_by_power = converted_results + logger.debug(f"Populated submitted/accepted order and result history for phase {completed_phase}.") + phase_summary = game.phase_summaries.get(current_phase, "(Summary not generated)") all_orders_this_phase = game.order_history.get(current_short_phase, {})