AI_Diplomacy/ai_diplomacy/utils.py
AlxAI 9fc25f2fec Add comprehensive Diplomacy analysis with visualizations
- Added diplomacy_unified_analysis_final.py: Complete analysis script with CSV-only approach
- Added DIPLOMACY_ANALYSIS_DOCUMENTATION.md: Comprehensive project documentation
- Added visualization_experiments_log.md: Detailed development history
- Added visualization_results/: AAAI-quality visualizations showing model evolution
- Fixed old format success calculation bug (results keyed by unit location)
- Demonstrated AI evolution from passive to active play across 61 models
- Updated .gitignore to exclude results_alpha

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-27 13:29:29 -04:00

642 lines
No EOL
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dotenv import load_dotenv
import logging
import os
from typing import Dict, List, Tuple, Set, Optional
from diplomacy import Game
import csv
from typing import TYPE_CHECKING
import random
import string
import json
import asyncio
from openai import RateLimitError, APIConnectionError, APITimeoutError
import aiohttp
import requests
from pathlib import Path
from config import config
from models import POWERS_ORDER
# Avoid circular import for type hinting
if TYPE_CHECKING:
from .clients import BaseModelClient
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
# from .agent import DiplomacyAgent
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
load_dotenv()
def atomic_write_json(data: dict, filepath: str):
"""Writes a dictionary to a JSON file atomically."""
try:
# Ensure the directory exists
dir_name = os.path.dirname(filepath)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
# Write to a temporary file in the same directory
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
with open(temp_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
# Atomically rename the temporary file to the final destination
os.rename(temp_filepath, filepath)
except Exception as e:
logger.error(f"Failed to perform atomic write to {filepath}: {e}", exc_info=True)
# Clean up temp file if it exists
if os.path.exists(temp_filepath):
try:
os.remove(temp_filepath)
except Exception as e_clean:
logger.error(f"Failed to clean up temp file {temp_filepath}: {e_clean}")
def assign_models_to_powers() -> Dict[str, str]:
"""
Example usage: define which model each power uses.
Return a dict: { power_name: model_id, ... }
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
deepseek-chat, deepseek-reasoner
openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
"""
# POWER MODELS
return {
"AUSTRIA": "o4-mini",
"ENGLAND": "o3",
"FRANCE": "gpt-5-reasoning-alpha-2025-07-19",
"GERMANY": "gpt-4.1",
"ITALY": "o4-mini",
"RUSSIA": "gpt-5-reasoning-alpha-2025-07-19",
"TURKEY": "o4-mini",
}
# TEST MODELS
"""
return {
"AUSTRIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ENGLAND": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"FRANCE": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"GERMANY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"RUSSIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"TURKEY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
}
"""
def get_special_models() -> Dict[str, str]:
"""
Define models for special purposes like phase summaries and formatting.
These can be overridden via environment variables:
- AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
- AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting (default: "google/gemini-2.5-flash-lite-preview-06-17")
Returns:
dict: {
"phase_summary": model for generating narrative phase summaries,
"formatter": model for formatting natural language to JSON
}
Examples:
# Use Claude for phase summaries
export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"
# Use a different Gemini model for formatting
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
"""
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
"""
Returns a dictionary mapping each orderable location to the list of valid orders.
"""
orderable_locs = game.get_orderable_locations(power_name)
all_possible = game.get_all_possible_orders()
result = {}
for loc in orderable_locs:
result[loc] = all_possible.get(loc, [])
return result
async def get_valid_orders(
game: Game,
client, # BaseModelClient instance
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
game_history,
model_error_stats,
agent_goals=None,
agent_relationships=None,
agent_private_diary_str=None,
log_file_path: str = None,
phase: str = None,
) -> Dict[str, List[str]]:
"""
Generates orders with the LLM, validates them by round-tripping through the
engine, and returns **both** the accepted and rejected orders so the caller
can record invalid attempts.
Returns
-------
dict : { "valid": [...], "invalid": [...] }
"""
# ── 1. Ask the model ───────────────────────────────────────
raw_orders = await client.get_orders(
game=game,
board_state=board_state,
power_name=power_name,
possible_orders=possible_orders,
conversation_text=game_history,
model_error_stats=model_error_stats,
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary_str=agent_private_diary_str,
log_file_path=log_file_path,
phase=phase,
)
invalid_info: list[str] = []
valid: list[str] = []
invalid: list[str] = []
# ── 2. Type check ──────────────────────────────────────────
if not isinstance(raw_orders, list):
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
model_error_stats[client.model_name]["order_decoding_errors"] += 1
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
# ── 3. Round-trip validation with engine ───────────────────
CODE_TO_ENGINE = {
"AUT": "AUSTRIA",
"ENG": "ENGLAND",
"FRA": "FRANCE",
"GER": "GERMANY",
"ITA": "ITALY",
"RUS": "RUSSIA",
"TUR": "TURKEY",
}
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
for move in raw_orders:
if not move or not move.strip():
continue
upper = move.upper()
# WAIVE is always valid
if upper == "WAIVE":
valid.append("WAIVE")
continue
game.clear_orders(engine_power)
game.set_orders(engine_power, [upper])
normed = game.get_orders(engine_power)
if normed: # accepted
valid.append(normed[0])
else: # rejected
invalid.append(upper)
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
game.clear_orders(engine_power) # clean slate for main engine flow
# ── 4. Legacy logging & stats updates ──────────────────────
if invalid_info: # at least one bad move
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
model_error_stats[client.model_name]["order_decoding_errors"] += 1
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
else:
logger.debug("[%s] All orders valid: %s", power_name, valid)
# ── 5. Fallback when nothing survives ─────────────────────
if not valid:
fallback = client.fallback_orders(possible_orders)
return {"valid": fallback, "invalid": invalid}
return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders(
issued_orders: Dict[str, List[str]],
accepted_orders_dict: Dict[str, List[str]],
game: Game,
) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
"""
Normalizes and compares issued orders against accepted orders from the game engine.
Uses the map's built-in normalization methods to ensure consistent formatting.
Args:
issued_orders: Dictionary of orders issued by power {power_name: [orders]}
accepted_orders_dict: Dictionary of orders accepted by the engine,
typically from game.get_state()["orders"].
game: The current Game object containing the map.
Returns:
Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: (orders_not_accepted, orders_not_issued)
- orders_not_accepted: Orders issued but not accepted by engine (normalized).
- orders_not_issued: Orders accepted by engine but not issued (normalized).
"""
game_map = game.map
def normalize_order(order: str) -> str:
# Inner function to normalize a single order string using the game map.
if not order:
return order
try:
# Use map's normalization methods directly
normalized = game_map.norm(order)
# Further split and normalize parts for complex orders if necessary
# (This part might need refinement depending on how complex orders are handled
# and represented after initial normalization by game_map.norm)
# Example (simplified, game_map.norm often handles this):
# Split support orders
# parts = normalized.split(" S ")
# normalized_parts = []
# for part in parts:
# move_parts = part.split(" - ")
# move_parts = [game_map.norm(p.strip()) for p in move_parts]
# move_parts = [game_map.aliases.get(p, p) for p in move_parts]
# normalized_parts.append(" - ".join(move_parts))
# return " S ".join(normalized_parts)
return normalized # Return the directly normalized string for now
except Exception as e:
logger.warning(f"Could not normalize order '{order}': {e}")
return order # Return original if normalization fails
orders_not_accepted = {}
orders_not_issued = {}
all_powers = set(issued_orders.keys()) | set(accepted_orders_dict.keys())
for pwr in all_powers:
# Normalize issued orders for the power, handling potential absence
issued_set = set()
if pwr in issued_orders:
try:
issued_set = {normalize_order(o) for o in issued_orders.get(pwr, []) if o}
except Exception as e:
logger.error(f"Error normalizing issued orders for {pwr}: {e}")
# Normalize accepted orders for the power, handling potential absence
accepted_set = set()
if pwr in accepted_orders_dict:
try:
accepted_set = {normalize_order(o) for o in accepted_orders_dict.get(pwr, []) if o}
except Exception as e:
logger.error(f"Error normalizing accepted orders for {pwr}: {e}")
# Compare the sets
missing_from_engine = issued_set - accepted_set
missing_from_issued = accepted_set - issued_set
if missing_from_engine:
orders_not_accepted[pwr] = missing_from_engine
if missing_from_issued:
orders_not_issued[pwr] = missing_from_issued
return orders_not_accepted, orders_not_issued
def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str:
"""
Resolve *fname* to an absolute path and return its contents.
Resolution rules (first match wins):
1. If *fname* is absolute -> use as-is.
2. If *prompts_dir* is given -> prompts_dir / fname
3. Otherwise -> <package_root>/prompts / fname
"""
fname = Path(fname)
if fname.is_absolute():
prompt_path = fname
else:
if prompts_dir is not None:
prompt_path = Path(prompts_dir) / fname
else:
package_root = Path(__file__).resolve().parent
prompt_path = package_root / "prompts" / fname
try:
content = prompt_path.read_text(encoding="utf-8").strip()
logger.debug(f"Loaded prompt from {prompt_path}, length: {len(content)}")
return content
except FileNotFoundError:
logger.error("Prompt file not found: %s", prompt_path)
raise Exception("Prompt file not found: " + str(prompt_path))
# == New LLM Response Logging Function ==
def log_llm_response(
log_file_path: str,
model_name: str,
power_name: Optional[str], # Optional for non-power-specific calls like summary
phase: str,
response_type: str,
raw_input_prompt: str, # Added new parameter for the raw input
raw_response: str,
success: str, # Changed from bool to str
):
"""Appends a raw LLM response to a CSV log file."""
try:
# Ensure the directory exists
log_dir = os.path.dirname(log_file_path)
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
os.makedirs(log_dir, exist_ok=True)
# Check if file exists and has content to determine if we need headers
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
# Added "raw_input" to fieldnames
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
writer = csv.DictWriter(
csvfile,
fieldnames=fieldnames,
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
escapechar="\\",
) # Use backslash for escaping
if not file_exists:
writer.writeheader() # Write header only if file is new
writer.writerow(
{
"model": model_name,
"power": power_name if power_name else "game", # Use 'game' if no specific power
"phase": phase,
"response_type": response_type,
"raw_input": raw_input_prompt, # Added raw_input to the row
"raw_response": raw_response,
"success": success,
}
)
except Exception as e:
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
# A tuple of exception types that we consider safe to retry.
# This includes network issues, timeouts, rate limits, and the ValueError
# we now raise for empty/invalid responses.
RETRYABLE_EXCEPTIONS = (
RateLimitError,
APIConnectionError,
APITimeoutError,
aiohttp.ClientError,
requests.RequestException,
asyncio.TimeoutError,
ValueError, # We explicitly raise this for empty responses, which might be a temporary glitch.
)
async def run_llm_and_log(
client: "BaseModelClient",
prompt: str,
power_name: Optional[str],
phase: str,
response_type: str,
temperature: float = 0.0,
*,
attempts: int = 5,
backoff_base: float = 1.0,
backoff_factor: float = 2.0,
jitter: float = 0.3,
) -> str:
"""
Calls `client.generate_response` with robust retry logic and returns the raw output.
This function handles exceptions gracefully:
- It retries on a specific set of `RETRYABLE_EXCEPTIONS` (e.g., network errors, rate limits).
- It immediately stops and re-raises critical exceptions like `KeyboardInterrupt`.
- It logs a warning for each failed retry attempt.
- On final failure after all retries, it logs a detailed error and re-raises the last
exception, ensuring the calling code is aware of the failure.
"""
last_exception: Optional[Exception] = None
for attempt in range(attempts):
try:
raw_response = await client.generate_response(prompt, temperature=temperature)
# The clients now raise ValueError, but this is a final safeguard.
if not raw_response or not raw_response.strip():
raise ValueError("LLM client returned an empty or whitespace-only string.")
# Success!
return raw_response
except RETRYABLE_EXCEPTIONS as e:
last_exception = e
if attempt == attempts - 1:
# This was the last attempt, so we'll fall through to the final error handling.
break
# Calculate exponential backoff with jitter
delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
logger.warning(
f"LLM call failed for {client.model_name}/{power_name} (Attempt {attempt + 1}/{attempts}). "
f"Error: {type(e).__name__}('{e}'). Retrying in {delay:.2f} seconds."
)
await asyncio.sleep(delay)
except (KeyboardInterrupt, asyncio.CancelledError):
# If the user hits Ctrl-C or the task is cancelled, stop immediately.
logger.warning(f"LLM call for {client.model_name}/{power_name} was cancelled or interrupted by user.")
raise # Re-raise to allow the application to exit cleanly.
except Exception as e:
last_exception = e
if attempt == attempts - 1:
# This was the last attempt, so we'll fall through to the final error handling.
break
# Calculate exponential backoff with jitter
delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
logger.error(
f"An unexpected error occurred during LLM call for {client.model_name}/{power_name}: {e}"
f"LLM call failed for {client.model_name}/{power_name} (Attempt {attempt + 1}/{attempts}). "
f"Error: {type(e).__name__}('{e}'). Retrying in {delay:.2f} seconds.",
exc_info=True,
)
await asyncio.sleep(delay)
# This part of the code is only reached if all retry attempts have failed.
final_error_message = (
f"API Error after {attempts} attempts for {client.model_name}/{power_name}/{response_type} "
f"in phase {phase}. Final error: {type(last_exception).__name__}('{last_exception}')"
)
logger.error(final_error_message, exc_info=last_exception)
# Re-raise the last captured exception so the caller knows the operation failed.
# 'from None' prevents chaining the exception with the try/except block itself.
raise last_exception from None
# This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation
# between trials.
# Temp=0 is important for better performance on deciding moves, and to
# ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
# Generate x lines of y random alphanumeric characters
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
return random_seed_block
def get_prompt_path(prompt_name: str) -> str:
"""Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting.
Args:
prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt")
Returns:
str: Either "unformatted/{prompt_name}" or just "{prompt_name}"
"""
if config.USE_UNFORMATTED_PROMPTS:
return f"unformatted/{prompt_name}"
else:
return prompt_name
def normalize_recipient_name(recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
return recipient
recipient = recipient.upper().strip()
# Handle common LLM typos and abbreviations found in data
name_mapping = {
"EGMANY": "GERMANY",
"GERMAN": "GERMANY",
"UK": "ENGLAND",
"BRIT": "ENGLAND",
"ENGLAND": "ENGLAND", # Keep as-is
"FRANCE": "FRANCE", # Keep as-is
"GERMANY": "GERMANY", # Keep as-is
"ITALY": "ITALY", # Keep as-is
"AUSTRIA": "AUSTRIA", # Keep as-is
"RUSSIA": "RUSSIA", # Keep as-is
"TURKEY": "TURKEY", # Keep as-is
"Germany": "GERMANY",
"England": "ENGLAND",
"France": "FRANCE",
"Italy": "ITALY",
"Russia": "RUSSIA",
"Austria": "AUSTRIA",
"Turkey": "TURKEY",
}
normalized = name_mapping.get(recipient, recipient)
return normalized
def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]:
"""
Resolve --prompts_dir into a mapping {power: Path}.
Accepts either a single path or 7 comma-separated paths.
Every path is normalised to an **absolute** Path object
(using Path(...).expanduser().resolve()) and checked for existence.
"""
if not raw:
return {}
parts = [s.strip() for s in raw.split(",") if s.strip()]
if len(parts) not in {1, 7}:
raise ValueError(
f"--prompts_dir expects 1 or 7 paths, got {len(parts)} "
f"({raw})"
)
# Expand/resolve & verify
def _norm(p: str) -> Path:
path = Path(p).expanduser().resolve()
if not path.is_dir():
raise FileNotFoundError(f"Prompt directory not found: {path}")
return path
if len(parts) == 1:
path = _norm(parts[0])
return {pwr: path for pwr in POWERS_ORDER}
paths = [_norm(p) for p in parts]
return dict(zip(POWERS_ORDER, paths))
async def atomic_write_json_async(data: dict, filepath: str):
"""Writes a dictionary to a JSON file atomically using async I/O."""
# Use asyncio.to_thread to run the synchronous atomic_write_json in a thread pool
# This prevents blocking the event loop while maintaining all the safety guarantees
await asyncio.to_thread(atomic_write_json, data, filepath)
async def log_llm_response_async(
log_file_path: str,
model_name: str,
power_name: Optional[str],
phase: str,
response_type: str,
raw_input_prompt: str,
raw_response: str,
success: str,
):
"""Async version of log_llm_response that runs in a thread pool."""
await asyncio.to_thread(
log_llm_response,
log_file_path,
model_name,
power_name,
phase,
response_type,
raw_input_prompt,
raw_response,
success
)
def get_board_state(board_state: dict, game: Game) -> Tuple[str, str]:
# Build units representation with power status and counts
units_lines = []
for p, units in board_state["units"].items():
units_str = ", ".join(units)
units_count = len(units)
line = f" {p}: {units_count} unit{'s' if units_count != 1 else ''} {units_str}"
if game.powers[p].is_eliminated():
line += " [ELIMINATED]"
units_lines.append(line)
units_repr = "\n".join(units_lines)
# Build centers representation with power status and counts
centers_lines = []
for p, centers in board_state["centers"].items():
centers_str = ", ".join(centers)
centers_count = len(centers)
line = f" {p}: {centers_count} supply center{'s' if centers_count != 1 else ''} {centers_str}"
if game.powers[p].is_eliminated():
line += " [ELIMINATED]"
centers_lines.append(line)
centers_repr = "\n".join(centers_lines)
return (units_repr, centers_repr)