mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
add order history
This commit is contained in:
parent
ebf26cf8a6
commit
b5a84867a1
10 changed files with 477 additions and 174 deletions
|
|
@ -484,8 +484,10 @@ class BaseModelClient:
|
|||
unanswered_messages += f"\nFrom {msg['sender']} in {msg['phase']}: {msg['content']}\n"
|
||||
else:
|
||||
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
||||
|
||||
return context + unanswered_messages + "\n\n" + instructions
|
||||
|
||||
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
||||
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
|
||||
return final_prompt
|
||||
|
||||
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from pprint import pformat
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -22,6 +23,7 @@ class Phase:
|
|||
name: str # e.g. "SPRING 1901"
|
||||
plans: Dict[str, str] = field(default_factory=dict)
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
||||
orders_by_power: Dict[str, List[str]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
|
@ -77,7 +79,7 @@ class Phase:
|
|||
# Join multiple results with commas
|
||||
result_str = f" ({', '.join(results[i])})"
|
||||
else:
|
||||
result_str = " (successful)"
|
||||
result_str = " (success)"
|
||||
result += f" {order}{result_str}\n"
|
||||
result += "\n"
|
||||
return result
|
||||
|
|
@ -151,7 +153,124 @@ class GameHistory:
|
|||
return {}
|
||||
return self.phases[-1].plans
|
||||
|
||||
# NEW METHOD
|
||||
|
||||
|
||||
|
||||
|
||||
def get_order_history_for_prompt(
    self,
    game: "Game",
    power_name: str,
    current_phase_name: str,
    num_movement_phases_to_show: int = 1,
) -> str:
    """Render recent orders and their outcomes as a text section for an LLM prompt.

    The look-back window starts at the N-th most recent stored phase whose name
    ends in "M" (N = ``num_movement_phases_to_show``) and runs up to, but not
    including, ``current_phase_name``. If fewer than N such phases exist, every
    earlier phase is reported. For each power the accepted orders are listed with
    an outcome tag, followed by submitted orders that were not accepted.

    Args:
        game: Object providing ``map.norm`` (used to normalise order strings) and,
            optionally, ``get_phase_history()``.
        power_name: Not read anywhere in this method; orders of every power are shown.
        current_phase_name: Phases with exactly this name are excluded.
        num_movement_phases_to_show: Number of "M" phases to reach back; values
            <= 0 produce the "no history" placeholder.

    Returns:
        The formatted history, or a short parenthesised placeholder string when
        there is nothing to report.
    """
    # ── guard clauses ──────────────────────────────────────────
    if not self.phases or num_movement_phases_to_show <= 0:
        return "\n(No order history to show)\n"

    prev = [p for p in self.phases if p.name != current_phase_name]
    if not prev:
        return "\n(No previous phases in history)\n"

    # Walk backwards until enough "M" phases were seen; `start` stays 0 (report
    # everything) when the history holds fewer of them than requested.
    # NOTE(review): presumably names are short-form such as "S1901M" — confirm.
    start, seen = 0, 0
    for i in range(len(prev) - 1, -1, -1):
        if prev[i].name.endswith("M"):
            seen += 1
            if seen >= num_movement_phases_to_show:
                start = i
                break
    phases_to_report = prev[start:]
    if not phases_to_report:
        return "\n(No relevant order history in look-back window)\n"

    # ── helpers ───────────────────────────────────────────────
    def _scalar(res):
        """Reduce a (possibly nested) list or dict result to one lowercase token.

        Lists are unwrapped to their first element (empty list -> ""); a dict
        yields its "outcome" or "result" value; anything else is stringified.
        """
        tag = res
        while isinstance(tag, list):
            tag = tag[0] if tag else ""
        if isinstance(tag, dict):
            tag = tag.get("outcome") or tag.get("result") or ""
        return str(tag).strip().lower()

    # Engine-side phase records by name; empty when `game` has no get_phase_history.
    engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
    eng2code = {
        "AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
        "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
    }
    norm = game.map.norm

    out_lines = ["**ORDER HISTORY (Recent Rounds)**"]

    for ph in phases_to_report:
        if not (ph.orders_by_power or ph.submitted_orders_by_power):
            continue
        # The phase header is emitted before any per-power content is known to exist.
        out_lines.append(f"\n--- Orders from Phase: {ph.name} ---")

        for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
            submitted = ph.submitted_orders_by_power.get(pwr, [])
            accepted = ph.orders_by_power.get(pwr, [])

            # Tolerate a single order stored as a bare string.
            if isinstance(submitted, str):
                submitted = [submitted]
            if isinstance(accepted, str):
                accepted = [accepted]

            def _norm_keep(o):  # keep WAIVE readable
                return o if o.upper() == "WAIVE" else norm(o)

            # Normalised form -> original text. `acc_norm` is never read below.
            sub_norm = {_norm_keep(o): o for o in submitted}
            acc_norm = {_norm_keep(o): o for o in accepted}
            if not submitted and not accepted:
                continue

            out_lines.append(f" {pwr}:")

            # Outcome source: the power's own entry if truthy, otherwise the whole
            # results mapping of the phase, otherwise an empty dict.
            # NOTE(review): the whole-mapping fallback only helps if results are
            # keyed by order or unit rather than by power — confirm against the writer.
            raw_res = ph.results_by_power.get(pwr) or ph.results_by_power or {}
            if not raw_res:
                eng = engine_phases.get(ph.name)
                if eng and hasattr(eng, "order_results"):
                    # Reverse lookup: full engine name whose code equals `pwr`
                    # (None when `pwr` is not a three-letter code).
                    key = next((k for k, v in eng2code.items() if v == pwr), None)
                    raw_res = (eng.order_results or {}).get(key, {})

            seen_ok = set()

            # 1) accepted orders, each tagged with its outcome
            for idx, order in enumerate(accepted):
                if isinstance(raw_res, dict):
                    # Try the full order text first, then its first two tokens.
                    res_raw = raw_res.get(order) or raw_res.get(" ".join(order.split()[:2]))
                elif isinstance(raw_res, list) and idx < len(raw_res):
                    res_raw = raw_res[idx]
                else:
                    res_raw = ""

                tag = _scalar(res_raw)
                # A missing or empty result is reported as success.
                if not tag or tag == "ok":
                    tag = "success"
                elif "bounce" in tag:
                    tag = "bounce"
                elif "void" == tag:
                    tag = "void: no effect"

                out_lines.append(f" {order} ({tag})")
                seen_ok.add(_norm_keep(order))

            # 2) submissions whose normalised form matches no accepted order
            for k in sorted(set(sub_norm) - seen_ok):
                out_lines.append(f" {sub_norm[k]} (Rejected by engine: invalid)")

    # Only the title line is present: no phase contributed anything.
    if len(out_lines) == 1:
        return "\n(No orders were issued in recent history)\n"
    return "\n".join(out_lines)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||
current_phase: Optional[Phase] = None
|
||||
for phase_obj in self.phases:
|
||||
|
|
@ -288,80 +407,4 @@ class GameHistory:
|
|||
'content': msg_content
|
||||
})
|
||||
|
||||
return ignored_by_power
|
||||
|
||||
# MODIFIED METHOD (renamed from get_game_history)
|
||||
def get_previous_phases_history(
|
||||
self, power_name: str, current_phase_name: str, include_plans: bool = True, num_prev_phases: int = 5
|
||||
) -> str:
|
||||
if not self.phases:
|
||||
return "\n(No game history available)\n"
|
||||
|
||||
relevant_phases = [p for p in self.phases if p.name != current_phase_name]
|
||||
|
||||
if not relevant_phases:
|
||||
return "\n(No previous game history before this round)\n"
|
||||
|
||||
phases_to_report = relevant_phases[-num_prev_phases:]
|
||||
|
||||
if not phases_to_report:
|
||||
return "\n(No previous game history available within the lookback window)\n"
|
||||
|
||||
game_history_str = ""
|
||||
|
||||
for phase_idx, phase in enumerate(phases_to_report):
|
||||
phase_content_str = f"\nPHASE: {phase.name}\n"
|
||||
current_phase_has_content = False
|
||||
|
||||
global_msgs = phase.get_global_messages()
|
||||
if global_msgs:
|
||||
phase_content_str += "\n GLOBAL MESSAGES:\n"
|
||||
phase_content_str += "".join([f" {line}\n" for line in global_msgs.strip().split('\n')])
|
||||
current_phase_has_content = True
|
||||
|
||||
private_msgs = phase.get_private_messages(power_name)
|
||||
if private_msgs:
|
||||
phase_content_str += "\n PRIVATE MESSAGES:\n"
|
||||
for other_power, messages in private_msgs.items():
|
||||
phase_content_str += f" Conversation with {other_power}:\n"
|
||||
phase_content_str += "".join([f" {line}\n" for line in messages.strip().split('\n')])
|
||||
current_phase_has_content = True
|
||||
|
||||
if phase.orders_by_power:
|
||||
phase_content_str += "\n ORDERS:\n"
|
||||
for power, orders in phase.orders_by_power.items():
|
||||
indicator = " (your power)" if power == power_name else ""
|
||||
phase_content_str += f" {power}{indicator}:\n"
|
||||
results = phase.results_by_power.get(power, [])
|
||||
for i, order in enumerate(orders):
|
||||
result_str = " (successful)"
|
||||
if i < len(results) and results[i] and not all(r == "" for r in results[i]):
|
||||
result_str = f" ({', '.join(results[i])})"
|
||||
phase_content_str += f" {order}{result_str}\n"
|
||||
phase_content_str += "\n"
|
||||
current_phase_has_content = True
|
||||
|
||||
if current_phase_has_content:
|
||||
if not game_history_str:
|
||||
game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n"
|
||||
game_history_str += phase_content_str
|
||||
if phase_idx < len(phases_to_report) -1 :
|
||||
game_history_str += " " + "-" * 48 + "\n"
|
||||
|
||||
if include_plans and phases_to_report:
|
||||
last_reported_previous_phase = phases_to_report[-1]
|
||||
if last_reported_previous_phase.plans:
|
||||
if not game_history_str:
|
||||
game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n"
|
||||
game_history_str += f"\n PLANS SUBMITTED FOR PHASE {last_reported_previous_phase.name}:\n"
|
||||
if power_name in last_reported_previous_phase.plans:
|
||||
game_history_str += f" Your Plan: {last_reported_previous_phase.plans[power_name]}\n"
|
||||
for p_other, plan_other in last_reported_previous_phase.plans.items():
|
||||
if p_other != power_name:
|
||||
game_history_str += f" {p_other}'s Plan: {plan_other}\n"
|
||||
game_history_str += "\n"
|
||||
|
||||
if not game_history_str.replace("**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n", "").strip():
|
||||
return "\n(No relevant previous game history to display)\n"
|
||||
|
||||
return game_history_str.strip()
|
||||
return ignored_by_power
|
||||
|
|
@ -124,9 +124,20 @@ def build_context_prompt(
|
|||
# Build {home_centers}
|
||||
home_centers_str = ", ".join(HOME_CENTERS.get(power_name.upper(), []))
|
||||
|
||||
order_history_str = game_history.get_order_history_for_prompt(
|
||||
game=game, # Pass the game object for normalization
|
||||
power_name=power_name,
|
||||
current_phase_name=year_phase,
|
||||
num_movement_phases_to_show=1
|
||||
)
|
||||
|
||||
# Replace token only if it exists (template may not include it)
|
||||
if "{home_centers}" in context_template:
|
||||
context_template = context_template.replace("{home_centers}", home_centers_str)
|
||||
|
||||
# Following the pattern for home_centers, use replace for safety
|
||||
if "{order_history}" in context_template:
|
||||
context_template = context_template.replace("{order_history}", order_history_str)
|
||||
|
||||
context = context_template.format(
|
||||
power_name=power_name,
|
||||
|
|
@ -202,5 +213,9 @@ def construct_order_generation_prompt(
|
|||
)
|
||||
|
||||
final_prompt = system_prompt + "\n\n" + context + "\n\n" + instructions
|
||||
|
||||
# Make the power names more LLM friendly
|
||||
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
|
||||
print(final_prompt)
|
||||
|
||||
return final_prompt
|
||||
return final_prompt
|
||||
|
|
@ -13,6 +13,9 @@ Relationships: {agent_relationships}
|
|||
RECENT PRIVATE DIARY ENTRIES (Your inner thoughts and plans):
|
||||
{agent_private_diary}
|
||||
|
||||
ORDER HISTORY
|
||||
{order_history}
|
||||
|
||||
GAME MAP
|
||||
Unit Locations:
|
||||
{all_unit_locations}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ Relationships: {agent_relationships}
|
|||
# Recent Private Diary Entries (Your inner thoughts and plans):
|
||||
{agent_private_diary}
|
||||
|
||||
# Order History
|
||||
{order_history}
|
||||
|
||||
# Game Map
|
||||
Unit Locations:
|
||||
{all_unit_locations}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,9 @@ Game State:
|
|||
TASK
|
||||
Analyze the negotiations, goals, relationships, and game state to:
|
||||
1. Summarize key outcomes and agreements
|
||||
2. State your strategic intent for {current_phase}
|
||||
2. State your specific intents for {current_phase}, including moves you have agreed to in negotiations and whether you intend to fulfil them.
|
||||
3. Update relationships as needed (Enemy, Unfriendly, Neutral, Friendly, Ally)
|
||||
4. Note which powers are not responding to your messages and consider adjusting your approach
|
||||
4. Important: You will not see the full negotiation log in the order decision phase, so you must transmit key information about the negotiations to your future self via this summary.
|
||||
|
||||
When powers ignore your messages, consider:
|
||||
- They may be intentionally avoiding commitment
|
||||
|
|
|
|||
|
|
@ -102,110 +102,100 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
|||
|
||||
async def get_valid_orders(
|
||||
game: Game,
|
||||
client, # This is the BaseModelClient instance
|
||||
client, # BaseModelClient instance
|
||||
board_state,
|
||||
power_name: str,
|
||||
possible_orders: Dict[str, List[str]],
|
||||
game_history, # This is GameHistory instance
|
||||
model_error_stats: Dict[str, Dict[str, int]],
|
||||
agent_goals: Optional[List[str]] = None,
|
||||
agent_relationships: Optional[Dict[str, str]] = None,
|
||||
agent_private_diary_str: Optional[str] = None, # Added new parameter
|
||||
game_history,
|
||||
model_error_stats,
|
||||
agent_goals=None,
|
||||
agent_relationships=None,
|
||||
agent_private_diary_str=None,
|
||||
log_file_path: str = None,
|
||||
phase: str = None,
|
||||
) -> List[str]:
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Tries up to 'max_retries' to generate and validate orders.
|
||||
If invalid, we append the error feedback to the conversation
|
||||
context for the next retry. If still invalid, return fallback.
|
||||
Generates orders with the LLM, validates them by round-tripping through the
|
||||
engine, and returns **both** the accepted and rejected orders so the caller
|
||||
can record invalid attempts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict : { "valid": [...], "invalid": [...] }
|
||||
"""
|
||||
|
||||
# Ask the LLM for orders
|
||||
orders = await client.get_orders(
|
||||
# ── 1. Ask the model ───────────────────────────────────────
|
||||
raw_orders = await client.get_orders(
|
||||
game=game,
|
||||
board_state=board_state,
|
||||
power_name=power_name,
|
||||
possible_orders=possible_orders,
|
||||
conversation_text=game_history, # Pass GameHistory instance
|
||||
conversation_text=game_history,
|
||||
model_error_stats=model_error_stats,
|
||||
agent_goals=agent_goals,
|
||||
agent_relationships=agent_relationships,
|
||||
agent_private_diary_str=agent_private_diary_str, # Pass the diary string
|
||||
agent_private_diary_str=agent_private_diary_str,
|
||||
log_file_path=log_file_path,
|
||||
phase=phase,
|
||||
)
|
||||
|
||||
# Initialize list to track invalid order information
|
||||
invalid_info = []
|
||||
|
||||
# Validate each order
|
||||
all_valid = True
|
||||
valid_orders = []
|
||||
|
||||
if not isinstance(orders, list): # Ensure orders is a list before iterating
|
||||
logger.warning(f"[{power_name}] Orders received from LLM is not a list: {orders}. Using fallback.")
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1 # Use client.model_name
|
||||
return client.fallback_orders(possible_orders)
|
||||
|
||||
for move in orders:
|
||||
# Skip empty orders
|
||||
if not move or move.strip() == "":
|
||||
continue
|
||||
|
||||
# Handle special case for WAIVE
|
||||
if move.upper() == "WAIVE":
|
||||
valid_orders.append(move)
|
||||
continue
|
||||
|
||||
# Example move: "A PAR H" -> unit="A PAR", order_part="H"
|
||||
tokens = move.split(" ", 2)
|
||||
if len(tokens) < 3:
|
||||
invalid_info.append(f"Order '{move}' is malformed; expected 'A PAR H' style.")
|
||||
all_valid = False
|
||||
continue
|
||||
|
||||
unit = " ".join(tokens[:2]) # e.g. "A PAR"
|
||||
order_part = tokens[2] # e.g. "H" or "S A MAR"
|
||||
invalid_info: list[str] = []
|
||||
valid: list[str] = []
|
||||
invalid: list[str] = []
|
||||
|
||||
# Use the internal game validation method
|
||||
if order_part == "B": # Build orders
|
||||
validity = 1 # hack because game._valid_order doesn't support 'B'
|
||||
elif order_part == "D": # Disband orders
|
||||
# Check if the unit is actually one of the power's units
|
||||
if unit in game.powers[power_name].units:
|
||||
validity = 1 # Simple check, engine handles full validation
|
||||
else:
|
||||
validity = 0
|
||||
else: # Movement, Support, Hold, Convoy, Retreat
|
||||
try:
|
||||
validity = game._valid_order(
|
||||
game.powers[power_name], unit, order_part, report=1
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error validating order '{move}': {e}")
|
||||
invalid_info.append(f"Order '{move}' caused an error: {e}")
|
||||
validity = 0
|
||||
all_valid = False
|
||||
|
||||
if validity == 1:
|
||||
valid_orders.append(move)
|
||||
else:
|
||||
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
|
||||
all_valid = False
|
||||
|
||||
# Log validation results
|
||||
if invalid_info:
|
||||
logger.debug(f"[{power_name}] Invalid orders: {', '.join(invalid_info)}")
|
||||
|
||||
if all_valid and valid_orders:
|
||||
logger.debug(f"[{power_name}] All orders valid: {valid_orders}")
|
||||
return valid_orders
|
||||
else:
|
||||
logger.debug(f"[{power_name}] Some orders invalid, using fallback.")
|
||||
# Use client.model_name for stats key, as power_name might not be unique if multiple agents use same model
|
||||
# ── 2. Type check ──────────────────────────────────────────
|
||||
if not isinstance(raw_orders, list):
|
||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
|
||||
power_name, raw_orders)
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
||||
|
||||
# ── 3. Round-trip validation with engine ───────────────────
|
||||
CODE_TO_ENGINE = {
|
||||
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
|
||||
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
|
||||
}
|
||||
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
||||
|
||||
for move in raw_orders:
|
||||
if not move or not move.strip():
|
||||
continue
|
||||
|
||||
upper = move.upper()
|
||||
|
||||
# WAIVE is always valid
|
||||
if upper == "WAIVE":
|
||||
valid.append("WAIVE")
|
||||
continue
|
||||
|
||||
game.clear_orders(engine_power)
|
||||
game.set_orders(engine_power, [upper])
|
||||
normed = game.get_orders(engine_power)
|
||||
|
||||
if normed: # accepted
|
||||
valid.append(normed[0])
|
||||
else: # rejected
|
||||
invalid.append(upper)
|
||||
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
|
||||
|
||||
game.clear_orders(engine_power) # clean slate for main engine flow
|
||||
|
||||
# ── 4. Legacy logging & stats updates ──────────────────────
|
||||
if invalid_info: # at least one bad move
|
||||
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
|
||||
else:
|
||||
logger.debug("[%s] All orders valid: %s", power_name, valid)
|
||||
|
||||
# ── 5. Fallback when nothing survives ─────────────────────
|
||||
if not valid:
|
||||
fallback = client.fallback_orders(possible_orders)
|
||||
return fallback
|
||||
return {"valid": fallback, "invalid": invalid}
|
||||
|
||||
return {"valid": valid, "invalid": invalid}
|
||||
|
||||
|
||||
|
||||
|
||||
def normalize_and_compare_orders(
|
||||
|
|
|
|||
|
|
@ -99,6 +99,22 @@ def _add_experiment_flags(p: argparse.ArgumentParser) -> None:
|
|||
"flag to lm_game for deterministic behaviour)."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--compare_to",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to another completed experiment directory. "
|
||||
"If supplied alongside --experiment_dir, the runner skips game "
|
||||
"execution and produces a statistical comparison between the two."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--sig_level",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="α for hypothesis tests in comparison mode (default 0.05).",
|
||||
)
|
||||
|
||||
|
||||
def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -169,6 +185,17 @@ def _add_lm_game_flags(p: argparse.ArgumentParser) -> None:
|
|||
"Defaults to the packaged prompts directory."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--simple_prompts",
|
||||
type=_str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help=(
|
||||
"When true (1 / true / yes) the engine switches to simpler prompts "
|
||||
"which low-midrange models handle better."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
|
@ -255,6 +282,20 @@ _RunInfo = collections.namedtuple(
|
|||
"_RunInfo", "index run_dir seed cmd_line returncode elapsed_s"
|
||||
)
|
||||
|
||||
def _str2bool(v: str | bool) -> bool:
|
||||
"""
|
||||
Accepts typical textual truthy / falsy values and returns a bool.
|
||||
Mirrors the helper used inside lm_game.
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
val = v.lower()
|
||||
if val in ("yes", "y", "true", "t", "1"):
|
||||
return True
|
||||
if val in ("no", "n", "false", "f", "0"):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError("boolean value expected")
|
||||
|
||||
|
||||
def _mk_run_dir(exp_dir: Path, idx: int) -> Path:
|
||||
run_dir = exp_dir / "runs" / f"run_{idx:05d}"
|
||||
|
|
@ -372,6 +413,13 @@ def main() -> None:
|
|||
log.info("Appending to existing experiment: %s", exp_dir)
|
||||
exp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if exp_args.compare_to is not None:
|
||||
from experiment_runner.analysis import compare_stats # local import
|
||||
|
||||
compare_stats.run(exp_dir, exp_args.compare_to, alpha=exp_args.sig_level)
|
||||
log.info("comparison complete; artefacts in %s/analysis/comparison", exp_dir)
|
||||
return
|
||||
|
||||
# Persist experiment-level config
|
||||
cfg_path = exp_dir / "config.json"
|
||||
if not cfg_path.exists(): # ← new guard
|
||||
|
|
|
|||
168
experiment_runner/analysis/compare_stats.py
Normal file
168
experiment_runner/analysis/compare_stats.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""
|
||||
experiment_runner.analysis.compare_stats
|
||||
----------------------------------------
|
||||
|
||||
Compares two completed Diplomacy experiments. Console output now
|
||||
shows *all* metrics whose (1 − α) confidence interval excludes 0 (α = 0.05, i.e. a 95 % CI, by default).
|
||||
|
||||
CSV files remain:
|
||||
|
||||
<expA>/analysis/comparison/
|
||||
comparison_aggregated_vs_<expB>.csv
|
||||
comparison_by_power_vs_<expB>.csv
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy import stats
|
||||
|
||||
# ───────────────────────── helpers ──────────────────────────
|
||||
# Identifier / label columns that are never compared as metrics, even when
# pandas parses them as numbers.
_EXCLUDE: set[str] = {
    "game_id",
    "llm_model",
    "power_name",
    "game_phase",
    "analyzed_response_type",
}


def _numeric_columns(df: pd.DataFrame) -> List[str]:
    """Return the numeric column names of *df*, in frame order, minus ``_EXCLUDE``."""
    kept: List[str] = []
    for name in df.select_dtypes("number").columns:
        if name in _EXCLUDE:
            continue
        kept.append(name)
    return kept
|
||||
|
||||
|
||||
def _load_games(exp: Path) -> pd.DataFrame:
    """Stack every per-game analysis CSV of experiment *exp* into one frame.

    Files are taken from
    ``<exp>/analysis/statistical_game_analysis/individual/*_game_analysis.csv``
    and concatenated with a fresh integer index.

    Raises:
        FileNotFoundError: if no matching CSV exists in that directory.
    """
    indiv = exp.joinpath("analysis", "statistical_game_analysis", "individual")
    csv_paths = [path for path in indiv.glob("*_game_analysis.csv")]
    if len(csv_paths) == 0:
        raise FileNotFoundError(f"no *_game_analysis.csv under {indiv}")
    frames = (pd.read_csv(path) for path in csv_paths)
    return pd.concat(frames, ignore_index=True)
|
||||
|
||||
|
||||
def _welch(a: np.ndarray, b: np.ndarray, alpha: float) -> Dict:
    """Compare sample *b* against sample *a* with Welch's unequal-variance t-test.

    Args:
        a: Observations of the baseline experiment (at least two values).
        b: Observations of the experiment being compared (at least two values).
        alpha: Significance level; the interval returned has coverage ``1 - alpha``.

    Returns:
        Dict with ``Mean_A``, ``Mean_B``, ``Diff`` (``Mean_B - Mean_A``),
        ``CI_low`` / ``CI_high`` for the difference, the two-sided ``p_value``,
        ``Cohen_d`` and the sample sizes ``n_A`` / ``n_B``.
    """
    _t, p_val = stats.ttest_ind(a, b, equal_var=False)
    mean_a, mean_b = a.mean(), b.mean()
    diff = mean_b - mean_a

    var_a, var_b = a.var(ddof=1), b.var(ddof=1)

    # Effect size scaled by the root of the averaged sample variances.
    pooled_sd = np.sqrt((var_a + var_b) / 2)
    cohen_d = diff / pooled_sd if pooled_sd else np.nan

    # Squared standard error contributed by each sample.
    sq_se_a, sq_se_b = var_a / len(a), var_b / len(b)
    se = np.sqrt(sq_se_a + sq_se_b)

    # Welch–Satterthwaite degrees of freedom, so the interval is built from the
    # same t distribution as the Welch p-value above. The pooled value
    # n_a + n_b - 2 is only correct for the equal-variance test.
    denom = sq_se_a ** 2 / (len(a) - 1) + sq_se_b ** 2 / (len(b) - 1)
    if denom:
        df = (sq_se_a + sq_se_b) ** 2 / denom
    else:
        # Both samples are constant: se is 0 and the interval collapses onto diff.
        df = len(a) + len(b) - 2
    ci = stats.t.ppf(1 - alpha / 2, df=df) * se

    return {
        "Mean_A": mean_a,
        "Mean_B": mean_b,
        "Diff": diff,
        "CI_low": diff - ci,
        "CI_high": diff + ci,
        "p_value": p_val,
        "Cohen_d": cohen_d,
        "n_A": len(a),
        "n_B": len(b),
    }
|
||||
|
||||
|
||||
# ───────────────────────── console formatting ─────────────────────────
|
||||
def _fmt_row(label: str, r: Dict, width: int) -> str:
    """Format one comparison result as a single console line.

    Args:
        label: Metric name, left-aligned and padded to *width* characters.
        r: Mapping with ``Diff``, ``Mean_A``, ``Mean_B``, ``CI_low``,
            ``CI_high``, ``p_value`` and ``Cohen_d``.
        width: Column width reserved for the label.
    """
    interval = f"[{r['CI_low']:+.2f}, {r['CI_high']:+.2f}]"
    cells = [
        f" {label:<{width}}",
        f"{r['Diff']:+6.2f}",
        f"({r['Mean_A']:.2f} → {r['Mean_B']:.2f})",
        f"95%CI {interval:<17}",
        f"p={r['p_value']:.4g}",
        f"d={r['Cohen_d']:+.2f}",
    ]
    # Every cell in the layout is separated by exactly one space.
    return " ".join(cells)
|
||||
|
||||
|
||||
def _print_hdr(title: str) -> None:
    """Print *title* after a blank line, underlined with one ``─`` per character."""
    underline = "─" * len(title)
    print(f"\n{title}\n{underline}")
|
||||
|
||||
|
||||
def _significant(df: pd.DataFrame, alpha: float) -> pd.DataFrame:
    """Return the rows of *df* that count as significant, ordered by p-value.

    A row is kept when its confidence interval lies strictly on one side of 0
    or when its ``p_value`` is below *alpha*.

    Args:
        df: Comparison table with ``CI_low``, ``CI_high`` and ``p_value`` columns.
            An empty frame, including one without any columns, is accepted.
        alpha: Significance threshold applied to ``p_value``.

    Returns:
        A re-indexed copy holding only the significant rows; empty when *df* is empty.
    """
    # A table built from zero result rows has no columns at all, so the column
    # look-ups below would raise KeyError. There is nothing to select anyway.
    if df.empty:
        return df.copy().reset_index(drop=True)

    ci_above_zero = (df["CI_low"] > 0) & (df["CI_high"] > 0)
    ci_below_zero = (df["CI_low"] < 0) & (df["CI_high"] < 0)
    p_below_alpha = df["p_value"] < alpha

    sig = df[ci_above_zero | ci_below_zero | p_below_alpha].copy()
    return sig.sort_values("p_value").reset_index(drop=True)
|
||||
|
||||
|
||||
# ───────────────────────── public API ─────────────────────────
|
||||
def run(exp_a: Path, exp_b: Path, alpha: float = 0.05) -> None:
    """Compare experiment *exp_b* against *exp_a* and report the differences.

    Loads the per-game analysis tables of both experiments, tests every numeric
    metric they share (first on per-game means across powers, then per power),
    writes both result tables as CSV under ``<exp_a>/analysis/comparison`` and
    prints the significant rows to the console.

    Args:
        exp_a: Baseline experiment directory; it also receives the output files.
        exp_b: Experiment directory compared against the baseline.
        alpha: Significance level handed to the statistical helpers.
    """
    games_a = _load_games(exp_a)
    games_b = _load_games(exp_b)

    shared = set(_numeric_columns(games_a)).intersection(_numeric_columns(games_b))
    metrics = sorted(shared)
    if not metrics:
        print("No overlapping numeric metrics to compare.")
        return

    out_dir = exp_a / "analysis" / "comparison"
    out_dir.mkdir(parents=True, exist_ok=True)

    # ── section 1: aggregated across powers ───────────────────────────
    # One observation per game: the metric averaged over that game's rows.
    rows_agg: List[Dict] = []
    for metric in metrics:
        per_game_a = games_a.groupby("game_id")[metric].mean().dropna().to_numpy()
        per_game_b = games_b.groupby("game_id")[metric].mean().dropna().to_numpy()
        if min(len(per_game_a), len(per_game_b)) >= 2:
            rows_agg.append({"Metric": metric, **_welch(per_game_a, per_game_b, alpha)})

    agg_df = pd.DataFrame(rows_agg)
    agg_csv = out_dir / f"comparison_aggregated_vs_{exp_b.name}.csv"
    agg_df.to_csv(agg_csv, index=False)

    sig_agg = _significant(agg_df, alpha)
    if sig_agg.empty:
        _print_hdr("Aggregated Across Powers – no metric significant at 95 % CI")
    else:
        first = sig_agg.iloc[0]
        n_a, n_b = int(first["n_A"]), int(first["n_B"])
        _print_hdr(f"Aggregated Across Powers – significant at 95 % CI (nA={n_a}, nB={n_b})")
        label_w = max(len(m) for m in sig_agg["Metric"]) + 2
        for _, row in sig_agg.iterrows():
            print(_fmt_row(row["Metric"], row, label_w))

    # ── section 2: per-power breakdown ───────────────────────────────
    rows_pow: List[Dict] = []
    powers = sorted(set(games_a["power_name"]) & set(games_b["power_name"]))
    for power in powers:
        only_a = games_a[games_a["power_name"] == power]
        only_b = games_b[games_b["power_name"] == power]
        for metric in metrics:
            vals_a = only_a[metric].dropna().to_numpy()
            vals_b = only_b[metric].dropna().to_numpy()
            if min(len(vals_a), len(vals_b)) >= 2:
                rows_pow.append(
                    {"Power": power, "Metric": metric, **_welch(vals_a, vals_b, alpha)}
                )

    pow_df = pd.DataFrame(rows_pow)
    pow_csv = out_dir / f"comparison_by_power_vs_{exp_b.name}.csv"
    pow_df.to_csv(pow_csv, index=False)

    sig_pow = _significant(pow_df, alpha)
    if sig_pow.empty:
        _print_hdr("Per-Power Breakdown – no metric significant at 95 % CI")
    else:
        _print_hdr(f"Per-Power Breakdown – metrics significant at 95 % CI (α={alpha})")
        label_w = max(len(m) for m in sig_pow["Metric"]) + 2
        for power in powers:
            power_rows = sig_pow[sig_pow["Power"] == power]
            if power_rows.empty:
                continue
            first = power_rows.iloc[0]
            n_a, n_b = int(first["n_A"]), int(first["n_B"])
            print(f"{power} (nA={n_a}, nB={n_b})")
            for _, row in power_rows.iterrows():
                print(_fmt_row(row["Metric"], row, label_w))

    # ── summary of file outputs ───────────────────────────────────────
    print("\nCSV outputs:")
    for csv_path in (agg_csv, pow_csv):
        print(f" • {csv_path}")
|
||||
47
lm_game.py
47
lm_game.py
|
|
@ -313,18 +313,30 @@ async def main():
|
|||
|
||||
active_powers = [p for p, a in agents.items() if not game.powers[p].is_eliminated()]
|
||||
order_power_names = [p for p in active_powers if gather_possible_orders(game, p)]
|
||||
submitted_orders_this_phase = defaultdict(list)
|
||||
|
||||
for i, result in enumerate(order_results):
|
||||
p_name = order_power_names[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error getting orders for {p_name}: {result}", exc_info=result)
|
||||
game.set_orders(p_name, [])
|
||||
else:
|
||||
orders = result
|
||||
game.set_orders(p_name, orders)
|
||||
if orders:
|
||||
await agents[p_name].generate_order_diary_entry(game, orders, llm_log_file_path)
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error("Error getting orders for %s: %s", p_name, result, exc_info=result)
|
||||
valid, invalid = [], []
|
||||
else:
|
||||
valid = result.get("valid", [])
|
||||
invalid = result.get("invalid", [])
|
||||
|
||||
# what the engine will actually execute
|
||||
game.set_orders(p_name, valid)
|
||||
|
||||
# what we record for prompt/history purposes
|
||||
submitted_orders_this_phase[p_name] = valid + invalid
|
||||
|
||||
# optional: diary entry only for the orders we tried to submit
|
||||
if valid or invalid:
|
||||
await agents[p_name].generate_order_diary_entry(
|
||||
game, valid + invalid, llm_log_file_path
|
||||
)
|
||||
|
||||
# --- 4d. Process Phase ---
|
||||
completed_phase = current_phase
|
||||
game.process()
|
||||
|
|
@ -333,6 +345,25 @@ async def main():
|
|||
logger.info(f"{power_name}: {power.centers}")
|
||||
|
||||
# --- 4e. Post-Processing and State Updates ---
|
||||
phase_history_from_game = game.get_phase_history()
|
||||
if phase_history_from_game:
|
||||
last_phase_from_game = phase_history_from_game[-1]
|
||||
if last_phase_from_game.name == completed_phase:
|
||||
phase_obj_in_my_history = game_history._get_phase(completed_phase)
|
||||
if phase_obj_in_my_history:
|
||||
# Store the orders the agents generated
|
||||
phase_obj_in_my_history.submitted_orders_by_power = submitted_orders_this_phase
|
||||
# Store the orders the engine actually accepted
|
||||
phase_obj_in_my_history.orders_by_power = last_phase_from_game.orders
|
||||
|
||||
# Store the results for the accepted orders
|
||||
converted_results = defaultdict(list)
|
||||
if last_phase_from_game.results:
|
||||
for pwr, res_list in last_phase_from_game.results.items():
|
||||
converted_results[pwr] = [[res] for res in res_list]
|
||||
phase_obj_in_my_history.results_by_power = converted_results
|
||||
logger.debug(f"Populated submitted/accepted order and result history for phase {completed_phase}.")
|
||||
|
||||
phase_summary = game.phase_summaries.get(current_phase, "(Summary not generated)")
|
||||
all_orders_this_phase = game.order_history.get(current_short_phase, {})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue