AI_Diplomacy/ai_diplomacy/utils.py
2025-07-03 14:27:50 -07:00

499 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dotenv import load_dotenv
import logging
import os
from typing import Dict, List, Tuple, Set, Optional
from diplomacy import Game
import csv
from typing import TYPE_CHECKING
import random
import string
import json
import asyncio
from ..config import config

# Avoid circular import for type hinting
if TYPE_CHECKING:
    from .clients import BaseModelClient
    # If DiplomacyAgent is used for type hinting for an 'agent' parameter:
    # from .agent import DiplomacyAgent

# Module-level logger; INFO level is forced here rather than left to callers.
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
# NOTE(review): basicConfig() configures the ROOT logger as an import-time
# side effect — confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO)

# Load environment variables from .env once at import so model clients see them.
load_dotenv()
def atomic_write_json(data: dict, filepath: str):
    """
    Write *data* to *filepath* as JSON atomically.

    The payload is first written to a PID-suffixed temp file in the same
    directory, then moved into place in one step, so concurrent readers
    never observe a half-written file. Errors are logged, not raised, and
    the temp file is removed on failure.

    Args:
        data: JSON-serializable dictionary to persist.
        filepath: Destination path; parent directories are created as needed.
    """
    # Bind the temp path before the try block so the cleanup code in the
    # except handler can never hit a NameError (previously it could, if
    # makedirs raised before the assignment).
    temp_filepath = f"{filepath}.tmp.{os.getpid()}"
    try:
        # Ensure the destination directory exists
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Write to a temporary file in the same directory (same filesystem,
        # so the final move cannot degrade into a copy).
        with open(temp_filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        # os.replace() is atomic and, unlike os.rename(), also overwrites an
        # existing destination on Windows.
        os.replace(temp_filepath, filepath)
    except Exception as e:
        logger.error(f"Failed to perform atomic write to {filepath}: {e}", exc_info=True)
        # Clean up temp file if it exists
        if os.path.exists(temp_filepath):
            try:
                os.remove(temp_filepath)
            except Exception as e_clean:
                logger.error(f"Failed to clean up temp file {temp_filepath}: {e_clean}")
def assign_models_to_powers() -> Dict[str, str]:
    """
    Define which model each power uses.

    Returns:
        dict: {power_name: model_id} for the seven standard powers
        (AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY).

    Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
    claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
    gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
    deepseek-chat, deepseek-reasoner
    openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
    openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
    openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
    """
    # POWER MODELS (previous per-power assignment, kept for reference):
    #   AUSTRIA: openrouter-google/gemini-2.5-flash-preview-05-20
    #   ENGLAND: openrouter-moonshotai/kimi-dev-72b:free
    #   FRANCE:  together-arcee-ai/AFM-4.5B-Preview
    #   GERMANY: openrouter-google/gemini-2.5-flash-lite-preview-06-17
    #   ITALY:   together-lgai/exaone-deep-32b
    #   RUSSIA:  deepseek-reasoner
    #   TURKEY:  openrouter-cohere/command-a
    # TEST MODELS: every power currently runs the same model.
    test_model = "openrouter-mistralai/mistral-small-3.2-24b-instruct"
    powers = ("AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY")
    return {power: test_model for power in powers}
def get_special_models() -> Dict[str, str]:
    """
    Return the models used for special purposes (summaries and formatting).

    Values come from config and can be overridden via environment variables:
    - AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
    - AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting
      (default: "google/gemini-2.5-flash-lite-preview-06-17")

    Returns:
        dict: {
            "phase_summary": model for generating narrative phase summaries,
            "formatter": model for formatting natural language to JSON
        }

    Examples:
        # Use Claude for phase summaries
        export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"
        # Use a different Gemini model for formatting
        export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
    """
    special_models = {
        "phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
        "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL,
    }
    return special_models
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
    """
    Map each of *power_name*'s orderable locations to its valid order list.

    Locations missing from the engine's possible-order table map to [].
    """
    all_possible = game.get_all_possible_orders()
    return {loc: all_possible.get(loc, []) for loc in game.get_orderable_locations(power_name)}
async def get_valid_orders(
    game: Game,
    client,  # BaseModelClient instance
    board_state,
    power_name: str,
    possible_orders: Dict[str, List[str]],
    game_history,
    model_error_stats,
    agent_goals=None,
    agent_relationships=None,
    agent_private_diary_str=None,
    log_file_path: str = None,
    phase: str = None,
) -> Dict[str, List[str]]:
    """
    Generates orders with the LLM, validates them by round-tripping through the
    engine, and returns **both** the accepted and rejected orders so the caller
    can record invalid attempts.

    NOTE: validation temporarily sets orders on *game* for this power and
    clears them afterwards, so the engine's order state is left empty for
    this power on return.

    Returns
    -------
    dict : { "valid": [...], "invalid": [...] }
    """
    # ── 1. Ask the model ───────────────────────────────────────
    raw_orders = await client.get_orders(
        game=game,
        board_state=board_state,
        power_name=power_name,
        possible_orders=possible_orders,
        conversation_text=game_history,
        model_error_stats=model_error_stats,
        agent_goals=agent_goals,
        agent_relationships=agent_relationships,
        agent_private_diary_str=agent_private_diary_str,
        log_file_path=log_file_path,
        phase=phase,
    )
    invalid_info: list[str] = []  # human-readable descriptions, for logging only
    valid: list[str] = []
    invalid: list[str] = []
    # ── 2. Type check ──────────────────────────────────────────
    # Anything that isn't a list means the LLM output couldn't be decoded;
    # fall back immediately and count it as a decoding error.
    if not isinstance(raw_orders, list):
        logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        return {"valid": client.fallback_orders(possible_orders), "invalid": []}
    # ── 3. Round-trip validation with engine ───────────────────
    # Map 3-letter power codes to the engine's full power names.
    CODE_TO_ENGINE = {
        "AUT": "AUSTRIA",
        "ENG": "ENGLAND",
        "FRA": "FRANCE",
        "GER": "GERMANY",
        "ITA": "ITALY",
        "RUS": "RUSSIA",
        "TUR": "TURKEY",
    }
    # NOTE(review): raises KeyError if power_name is neither an engine power
    # nor one of the 3-letter codes above — confirm callers guarantee this.
    engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
    for move in raw_orders:
        # Skip empty / whitespace-only entries.
        if not move or not move.strip():
            continue
        upper = move.upper()
        # WAIVE is always valid
        if upper == "WAIVE":
            valid.append("WAIVE")
            continue
        # Round-trip: set the single order, then read back what the engine
        # kept. An empty read-back means the engine rejected it.
        game.clear_orders(engine_power)
        game.set_orders(engine_power, [upper])
        normed = game.get_orders(engine_power)
        if normed:  # accepted (use the engine's normalized form)
            valid.append(normed[0])
        else:  # rejected
            invalid.append(upper)
            invalid_info.append(f"Order '{move}' is invalid for {power_name}")
    game.clear_orders(engine_power)  # clean slate for main engine flow
    # ── 4. Legacy logging & stats updates ──────────────────────
    if invalid_info:  # at least one bad move
        logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        logger.debug("[%s] Some orders invalid, using fallback.", power_name)
    else:
        logger.debug("[%s] All orders valid: %s", power_name, valid)
    # ── 5. Fallback when nothing survives ─────────────────────
    if not valid:
        fallback = client.fallback_orders(possible_orders)
        return {"valid": fallback, "invalid": invalid}
    return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders(
    issued_orders: Dict[str, List[str]],
    accepted_orders_dict: Dict[str, List[str]],
    game: Game,
) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Compare issued orders against engine-accepted orders after normalization.

    Both sides are run through the game map's ``norm`` method so that pure
    formatting differences do not register as mismatches.

    Args:
        issued_orders: {power_name: [orders]} issued by each power.
        accepted_orders_dict: {power_name: [orders]} accepted by the engine,
            typically from game.get_state()["orders"].
        game: Current Game object; supplies the map used for normalization.

    Returns:
        Tuple (orders_not_accepted, orders_not_issued), each {power: set}:
        - orders_not_accepted: issued but not accepted by the engine.
        - orders_not_issued: accepted by the engine but never issued.
    """
    game_map = game.map

    def _normalize(order: str) -> str:
        # Defer entirely to the map's normalizer; on any failure keep the
        # original string so the comparison can still proceed.
        if not order:
            return order
        try:
            return game_map.norm(order)
        except Exception as e:
            logger.warning(f"Could not normalize order '{order}': {e}")
            return order

    def _normalized_set(source: Dict[str, List[str]], pwr: str, label: str) -> Set[str]:
        # Normalized order set for one power; powers absent from the source
        # (or whose normalization blows up wholesale) yield an empty set.
        if pwr not in source:
            return set()
        try:
            return {_normalize(o) for o in source.get(pwr, []) if o}
        except Exception as e:
            logger.error(f"Error normalizing {label} orders for {pwr}: {e}")
            return set()

    orders_not_accepted: Dict[str, Set[str]] = {}
    orders_not_issued: Dict[str, Set[str]] = {}
    for pwr in set(issued_orders) | set(accepted_orders_dict):
        issued_set = _normalized_set(issued_orders, pwr, "issued")
        accepted_set = _normalized_set(accepted_orders_dict, pwr, "accepted")
        rejected = issued_set - accepted_set
        unissued = accepted_set - issued_set
        if rejected:
            orders_not_accepted[pwr] = rejected
        if unissued:
            orders_not_issued[pwr] = unissued
    return orders_not_accepted, orders_not_issued
# Helper to load prompt text from file relative to the expected 'prompts' dir
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
    """
    Return the stripped contents of *filename*, resolving its path at most once.

    Resolution rules
    ----------------
    1. Absolute *filename* → used as-is.
    2. Relative *filename* (with or without a subdirectory component) →
       joined onto *prompts_dir* when provided, otherwise onto the package's
       default ``prompts/`` directory next to this module.

    Returns "" (and logs an error) when the file does not exist.
    """
    # The previous implementation had a separate branch for filenames that
    # contain a directory component, but it performed exactly the same joins
    # as the plain-relative case — collapsed here (its docstring also wrongly
    # claimed such paths were used directly).
    if os.path.isabs(filename):
        prompt_path = filename
    elif prompts_dir:
        prompt_path = os.path.join(prompts_dir, filename)
    else:
        default_dir = os.path.join(os.path.dirname(__file__), "prompts")
        prompt_path = os.path.join(default_dir, filename)
    try:
        with open(prompt_path, "r", encoding="utf-8") as fh:
            return fh.read().strip()
    except FileNotFoundError:
        logger.error(f"Prompt file not found: {prompt_path}")
        return ""
# == New LLM Response Logging Function ==
def log_llm_response(
    log_file_path: str,
    model_name: str,
    power_name: Optional[str],  # Optional for non-power-specific calls like summary
    phase: str,
    response_type: str,
    raw_input_prompt: str,
    raw_response: str,
    success: str,
):
    """
    Append one raw LLM call record to a CSV log file.

    Creates the parent directory and writes the header row on first use.
    Failures are logged rather than raised so logging never kills a game.
    """
    fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
    row = {
        "model": model_name,
        "power": power_name if power_name else "game",  # 'game' when no specific power
        "phase": phase,
        "response_type": response_type,
        "raw_input": raw_input_prompt,
        "raw_response": raw_response,
        "success": success,
    }
    try:
        parent_dir = os.path.dirname(log_file_path)
        if parent_dir:  # path may be a bare filename with no directory part
            os.makedirs(parent_dir, exist_ok=True)
        # Header is needed only when the file is absent or still empty.
        needs_header = not (os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0)
        with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(
                csvfile,
                fieldnames=fieldnames,
                quoting=csv.QUOTE_ALL,  # quote everything: prompts contain commas/newlines
                escapechar="\\",
            )
            if needs_header:
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
async def run_llm_and_log(
    client: "BaseModelClient",
    prompt: str,
    log_file_path: str,  # Kept for context, but not used for logging here
    power_name: Optional[str],  # Kept for context, but not used for logging here
    phase: str,  # Kept for context, but not used for logging here
    response_type: str,  # Kept for context, but not used for logging here
    temperature: float = 0.0,
    *,
    attempts: int = 5,
    backoff_base: float = 1.0,
    backoff_factor: float = 2.0,
    jitter: float = 0.3,
) -> str:
    """
    Calls `client.generate_response` with retry logic and returns the raw output.

    Retries up to *attempts* times with exponential backoff plus jitter.
    Logging behaviour:
    - On a final failure (after all retries), logs a single error.
    - If a retry eventually succeeds, no errors are logged.

    Returns "" if every attempt fails (or raises an empty response).
    """
    raw_response = ""  # Initialize in case of error
    for attempt in range(attempts):
        try:
            raw_response = await client.generate_response(prompt, temperature=temperature)
            if not raw_response:
                raise Exception("Empty response from client")
            return raw_response
        except Exception as e:
            if attempt == attempts - 1:
                logger.error(
                    f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
                    exc_info=True,
                )
                # BUG FIX: previously the code still slept the full backoff
                # delay after the final failed attempt; skip it.
                break
            # Back-off before the next attempt
            delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
            await asyncio.sleep(delay)
    return raw_response
# This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation
# between trials.
# Temp=0 is important for better performance on deciding moves, and to
# ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
    """
    Build a block of random alphanumeric lines wrapped in seed markers.

    Injected into the system prompt so temperature-0 runs still vary between
    trials while keeping deterministic, JSON-friendly model behaviour.
    """
    alphabet = string.ascii_letters + string.digits
    seed_lines = []
    for _ in range(n_lines):
        seed_lines.append("".join(random.choices(alphabet, k=n_chars_per_line)))
    body = "\n".join(seed_lines)
    return f"<RANDOM SEED PLEASE IGNORE>\n{body}\n</RANDOM SEED>"
def get_prompt_path(prompt_name: str) -> str:
    """Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting.

    Args:
        prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt")

    Returns:
        str: Either "unformatted/{prompt_name}" or just "{prompt_name}"
    """
    prefix = "unformatted/" if config.USE_UNFORMATTED_PROMPTS else ""
    return prefix + prompt_name
def normalize_recipient_name(recipient: str) -> str:
    """
    Normalize a power name received from an LLM to canonical upper-case form.

    Handles known typos and abbreviations observed in game data (e.g.
    "EGMANY" → "GERMANY", "UK" → "ENGLAND"). Unrecognized names are returned
    upper-cased and stripped; falsy input (empty string / None) is returned
    unchanged.
    """
    if not recipient:
        return recipient
    recipient = recipient.upper().strip()
    # Known LLM typos and abbreviations. The previous version also carried
    # mixed-case keys ("Germany", "England", ...) that were unreachable dead
    # code after the upper() above, and identity entries ("FRANCE": "FRANCE")
    # that are redundant with dict.get's default — both pruned; canonical
    # names simply pass through unchanged.
    name_mapping = {
        "EGMANY": "GERMANY",
        "GERMAN": "GERMANY",
        "UK": "ENGLAND",
        "BRIT": "ENGLAND",
    }
    return name_mapping.get(recipient, recipient)