AI_Diplomacy/ai_diplomacy/utils.py

403 lines
No EOL
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dotenv import load_dotenv
import logging
import os
from typing import Dict, List, Tuple, Set, Optional
from diplomacy import Game
import csv
from typing import TYPE_CHECKING
import random
import string
import json
# Avoid circular import for type hinting
if TYPE_CHECKING:
from .clients import BaseModelClient
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
# from .agent import DiplomacyAgent
# Module-level logger for this utilities module.
# NOTE(review): calling logging.basicConfig() at import time configures the
# root logger as a side effect — confirm this is intended for a library
# module rather than leaving configuration to the application entry point.
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
# Load environment variables (e.g. API keys) from a .env file at import time.
load_dotenv()
def atomic_write_json(data: dict, filepath: str):
    """Write *data* to *filepath* as JSON atomically.

    The payload is first written to a process-unique temp file in the same
    directory as the destination (same filesystem, so the final swap is a
    single rename), then moved over the destination in one step so readers
    never observe a partially-written file.

    Errors are logged, not raised; the temp file is cleaned up on failure.
    """
    # Compute the temp path up front so the except-block cleanup can always
    # reference it (previously a failure before assignment caused a NameError
    # inside the handler, masking the original error).
    temp_filepath = f"{filepath}.tmp.{os.getpid()}"
    try:
        # Ensure the destination directory exists.
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Write the full payload to the temp file first.
        with open(temp_filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        # os.replace (unlike os.rename) atomically overwrites an existing
        # destination on all platforms, including Windows.
        os.replace(temp_filepath, filepath)
    except Exception as e:
        logger.error(f"Failed to perform atomic write to {filepath}: {e}", exc_info=True)
        # Best-effort cleanup of the temp file if it was created.
        if os.path.exists(temp_filepath):
            try:
                os.remove(temp_filepath)
            except Exception as e_clean:
                logger.error(f"Failed to clean up temp file {temp_filepath}: {e_clean}")
def assign_models_to_powers() -> Dict[str, str]:
    """Map each Diplomacy power to the model id that will play it.

    Powers: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY.

    Supported model ids include: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
    claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022,
    claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219,
    gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
    deepseek-chat, deepseek-reasoner, and various openrouter-/together- prefixed ids.
    """
    # POWER MODELS — previous mixed roster, kept for reference:
    #   AUSTRIA: openrouter-google/gemini-2.5-flash-preview-05-20
    #   ENGLAND: openrouter-moonshotai/kimi-dev-72b:free
    #   FRANCE:  together-arcee-ai/AFM-4.5B-Preview
    #   GERMANY: openrouter-google/gemini-2.5-flash-lite-preview-06-17
    #   ITALY:   together-lgai/exaone-deep-32b
    #   RUSSIA:  deepseek-reasoner
    #   TURKEY:  openrouter-cohere/command-a
    # TEST MODELS — every power currently plays the same model.
    test_model = "openrouter-mistralai/mistral-small-3.2-24b-instruct"
    powers = ("AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY")
    return {power: test_model for power in powers}
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
    """Return {orderable location: list of valid orders} for *power_name*.

    Locations with no entry in the engine's possible-order table map to [].
    """
    every_order = game.get_all_possible_orders()
    return {
        location: every_order.get(location, [])
        for location in game.get_orderable_locations(power_name)
    }
async def get_valid_orders(
    game: Game,
    client,  # BaseModelClient instance
    board_state,
    power_name: str,
    possible_orders: Dict[str, List[str]],
    game_history,
    model_error_stats,
    agent_goals=None,
    agent_relationships=None,
    agent_private_diary_str=None,
    log_file_path: Optional[str] = None,
    phase: Optional[str] = None,
) -> Dict[str, List[str]]:
    """
    Generates orders with the LLM, validates them by round-tripping through the
    engine, and returns **both** the accepted and rejected orders so the caller
    can record invalid attempts.

    Side effect: temporarily sets and clears orders for this power on *game*
    during validation; the power's orders are cleared again at the end so the
    main engine flow starts from a clean slate.

    Returns
    -------
    dict : { "valid": [...], "invalid": [...] }
    """
    # ── 1. Ask the model ───────────────────────────────────────
    raw_orders = await client.get_orders(
        game=game,
        board_state=board_state,
        power_name=power_name,
        possible_orders=possible_orders,
        conversation_text=game_history,
        model_error_stats=model_error_stats,
        agent_goals=agent_goals,
        agent_relationships=agent_relationships,
        agent_private_diary_str=agent_private_diary_str,
        log_file_path=log_file_path,
        phase=phase,
    )
    invalid_info: list[str] = []  # human-readable note per rejected order (for logging)
    valid: list[str] = []
    invalid: list[str] = []
    # ── 2. Type check ──────────────────────────────────────────
    # A non-list response means the model output could not be decoded into
    # orders; use the client's fallback orders instead.
    # NOTE(review): assumes model_error_stats[client.model_name] already maps
    # "order_decoding_errors" to an int (e.g. nested defaultdict) — confirm.
    if not isinstance(raw_orders, list):
        logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
                       power_name, raw_orders)
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        return {"valid": client.fallback_orders(possible_orders), "invalid": []}
    # ── 3. Round-trip validation with engine ───────────────────
    # Accept either the engine's full power names or 3-letter codes.
    # NOTE(review): an unrecognized power_name raises KeyError here — confirm
    # callers only pass engine names or these codes.
    CODE_TO_ENGINE = {
        "AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
        "GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
    }
    engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
    for move in raw_orders:
        # Skip empty / whitespace-only entries.
        if not move or not move.strip():
            continue
        upper = move.upper()
        # WAIVE is always valid
        if upper == "WAIVE":
            valid.append("WAIVE")
            continue
        # Submit the single order to the engine; if the engine echoes a
        # normalized order back, it was accepted.
        game.clear_orders(engine_power)
        game.set_orders(engine_power, [upper])
        normed = game.get_orders(engine_power)
        if normed:  # accepted
            valid.append(normed[0])
        else:  # rejected
            invalid.append(upper)
            invalid_info.append(f"Order '{move}' is invalid for {power_name}")
    game.clear_orders(engine_power)  # clean slate for main engine flow
    # ── 4. Legacy logging & stats updates ──────────────────────
    if invalid_info:  # at least one bad move
        logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        logger.debug("[%s] Some orders invalid, using fallback.", power_name)
    else:
        logger.debug("[%s] All orders valid: %s", power_name, valid)
    # ── 5. Fallback when nothing survives ─────────────────────
    if not valid:
        fallback = client.fallback_orders(possible_orders)
        return {"valid": fallback, "invalid": invalid}
    return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders(
    issued_orders: Dict[str, List[str]],
    accepted_orders_dict: Dict[str, List[str]],
    game: Game,
) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Compare the orders each power issued against the orders the engine accepted.

    Both sides are normalized through the game map so that pure formatting
    differences do not register as mismatches.

    Args:
        issued_orders: Orders issued per power, {power_name: [orders]}.
        accepted_orders_dict: Orders the engine accepted, typically from
            game.get_state()["orders"].
        game: The current Game object providing the map used for normalization.

    Returns:
        Tuple of (orders_not_accepted, orders_not_issued):
        - orders_not_accepted: issued but not accepted by the engine (normalized).
        - orders_not_issued: accepted by the engine but not issued (normalized).
    """
    game_map = game.map

    def _norm(order: str) -> str:
        # Normalize one order string via the map; fall back to the raw
        # string if normalization fails so comparison can still proceed.
        if not order:
            return order
        try:
            return game_map.norm(order)
        except Exception as e:
            logger.warning(f"Could not normalize order '{order}': {e}")
            return order

    orders_not_accepted: Dict[str, Set[str]] = {}
    orders_not_issued: Dict[str, Set[str]] = {}
    for pwr in set(issued_orders) | set(accepted_orders_dict):
        # Normalized orders this power issued (empty if absent or on error).
        issued_set: Set[str] = set()
        if pwr in issued_orders:
            try:
                issued_set = {_norm(o) for o in issued_orders.get(pwr, []) if o}
            except Exception as e:
                logger.error(f"Error normalizing issued orders for {pwr}: {e}")
        # Normalized orders the engine accepted (empty if absent or on error).
        accepted_set: Set[str] = set()
        if pwr in accepted_orders_dict:
            try:
                accepted_set = {_norm(o) for o in accepted_orders_dict.get(pwr, []) if o}
            except Exception as e:
                logger.error(f"Error normalizing accepted orders for {pwr}: {e}")
        # Record the two one-sided differences, only when non-empty.
        rejected = issued_set - accepted_set
        unissued = accepted_set - issued_set
        if rejected:
            orders_not_accepted[pwr] = rejected
        if unissued:
            orders_not_issued[pwr] = unissued
    return orders_not_accepted, orders_not_issued
# Helper to load prompt text from file relative to the expected 'prompts' dir
# Helper to load prompt text from file relative to the expected 'prompts' dir
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
    """
    Return the stripped contents of a prompt file, or "" if it is missing.

    Path resolution:
    1. Absolute *filename* -> used as-is.
    2. Otherwise -> joined under *prompts_dir* when provided (and non-empty),
       else under the package's default ``prompts`` directory next to this
       module. Relative filenames that contain a directory component
       (e.g. ``x/y.txt``) are joined the same way as bare names.

    A missing file is logged and reported as "" rather than raised, so
    callers can treat an absent prompt as empty.
    """
    if os.path.isabs(filename):
        prompt_path = filename
    else:
        # The original four-way branch collapsed to this: every relative
        # filename is joined under prompts_dir when given, else the default.
        base_dir = prompts_dir or os.path.join(os.path.dirname(__file__), "prompts")
        prompt_path = os.path.join(base_dir, filename)
    try:
        with open(prompt_path, "r", encoding="utf-8") as fh:
            return fh.read().strip()
    except FileNotFoundError:
        logger.error(f"Prompt file not found: {prompt_path}")
        return ""
# == New LLM Response Logging Function ==
# == New LLM Response Logging Function ==
def log_llm_response(
    log_file_path: str,
    model_name: str,
    power_name: Optional[str],  # Optional for non-power-specific calls like summary
    phase: str,
    response_type: str,
    raw_input_prompt: str,
    raw_response: str,
    success: str,
):
    """Append one raw LLM interaction as a row to a CSV audit log.

    Creates the log directory and a header row on first write. All fields
    are quoted so prompts/responses containing commas or newlines survive
    the round trip. Failures are logged, never raised.
    """
    try:
        parent = os.path.dirname(log_file_path)
        if parent:  # path may be a bare filename with no directory part
            os.makedirs(parent, exist_ok=True)
        # Header is needed only when the file is new or still empty.
        needs_header = not (os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0)
        columns = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
        row = {
            "model": model_name,
            "power": power_name or "game",  # 'game' marks non-power-specific calls
            "phase": phase,
            "response_type": response_type,
            "raw_input": raw_input_prompt,
            "raw_response": raw_response,
            "success": success,
        }
        with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(
                csvfile,
                fieldnames=columns,
                quoting=csv.QUOTE_ALL,  # quote everything: commas/newlines in prompts
                escapechar='\\',
            )
            if needs_header:
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
# == New Async LLM Wrapper with Logging ==
# == New Async LLM Wrapper with Logging ==
async def run_llm_and_log(
    client: 'BaseModelClient',
    prompt: str,
    log_file_path: str,  # context only; CSV logging is done by the caller
    power_name: Optional[str],  # context only
    phase: str,  # context only
    response_type: str,  # context only
    temperature: float = 0.0,
) -> str:
    """Invoke the client's generate_response and return its raw output.

    On any API error the exception is logged and "" is returned so the
    caller can record the failure; CSV logging itself happens upstream.
    """
    try:
        return await client.generate_response(prompt, temperature=temperature)
    except Exception as e:
        # Log the API call error. The caller will decide how to log this in llm_responses.csv
        logger.error(f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}", exc_info=True)
        return ""  # empty string signals failure to the caller
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
    """Build a block of random alphanumeric filler lines for a system prompt.

    Injecting this lets us run with temperature=0 — which helps with move
    quality and valid JSON output — while still getting variation between
    trials.
    """
    alphabet = string.ascii_letters + string.digits
    lines = []
    for _ in range(n_lines):
        lines.append(''.join(random.choices(alphabet, k=n_chars_per_line)))
    body = "\n".join(lines)
    return f"<RANDOM SEED PLEASE IGNORE>\n{body}\n</RANDOM SEED>"