mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
499 lines
19 KiB
Python
499 lines
19 KiB
Python
from dotenv import load_dotenv
|
||
import logging
|
||
import os
|
||
from typing import Dict, List, Tuple, Set, Optional
|
||
from diplomacy import Game
|
||
import csv
|
||
from typing import TYPE_CHECKING
|
||
import random
|
||
import string
|
||
import json
|
||
import asyncio
|
||
|
||
from ..config import config
|
||
|
||
# Avoid circular import for type hinting
|
||
if TYPE_CHECKING:
|
||
from .clients import BaseModelClient
|
||
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
|
||
# from .agent import DiplomacyAgent
|
||
|
||
# Module-level logger for this utils module.
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
# NOTE(review): basicConfig here configures the ROOT logger as a side effect
# of merely importing this module — confirm this is intentional for library use.
logging.basicConfig(level=logging.INFO)

# Load environment variables (API keys, model overrides) from a .env file.
load_dotenv()
|
||
|
||
|
||
def atomic_write_json(data: dict, filepath: str):
    """Write *data* to *filepath* as JSON atomically.

    The payload is first written to a temporary file in the destination
    directory and then moved into place with ``os.replace``, so concurrent
    readers never observe a half-written file. Errors are logged, never
    raised (best-effort persistence).
    """
    # Defined up-front so the cleanup path can never hit a NameError when the
    # failure happens before the temp path is assigned (e.g. makedirs fails).
    temp_filepath = None
    try:
        # Ensure the destination directory exists.
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # Write to a temporary file in the same directory; including the PID
        # keeps concurrent writers from clobbering each other's temp file.
        temp_filepath = f"{filepath}.tmp.{os.getpid()}"
        with open(temp_filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)

        # os.replace is atomic on POSIX and, unlike os.rename, also
        # overwrites an existing destination on Windows.
        os.replace(temp_filepath, filepath)
    except Exception as e:
        logger.error(f"Failed to perform atomic write to {filepath}: {e}", exc_info=True)
        # Clean up the temp file if it was created.
        if temp_filepath and os.path.exists(temp_filepath):
            try:
                os.remove(temp_filepath)
            except Exception as e_clean:
                logger.error(f"Failed to clean up temp file {temp_filepath}: {e_clean}")
|
||
|
||
|
||
def assign_models_to_powers() -> Dict[str, str]:
    """
    Define which model each power uses.

    Return a dict: { power_name: model_id, ... }
    POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
    Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
    claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
    gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
    deepseek-chat, deepseek-reasoner
    openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
    openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
    openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
    """
    # POWER MODELS (currently disabled) — swap this mapping back in for real runs:
    # return {
    #     "AUSTRIA": "openrouter-google/gemini-2.5-flash-preview-05-20",
    #     "ENGLAND": "openrouter-moonshotai/kimi-dev-72b:free",
    #     "FRANCE": "together-arcee-ai/AFM-4.5B-Preview",
    #     "GERMANY": "openrouter-google/gemini-2.5-flash-lite-preview-06-17",
    #     "ITALY": "together-lgai/exaone-deep-32b",
    #     "RUSSIA": "deepseek-reasoner",
    #     "TURKEY": "openrouter-cohere/command-a",
    # }

    # TEST MODELS: every power plays the same small model.
    test_model = "openrouter-mistralai/mistral-small-3.2-24b-instruct"
    return {
        power: test_model
        for power in ("AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY")
    }
|
||
|
||
|
||
def get_special_models() -> Dict[str, str]:
    """
    Return the models used for special purposes (phase summaries, formatting).

    Both values come from the package config, which can be overridden via
    environment variables:
      - AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
      - AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting
        (default: "google/gemini-2.5-flash-lite-preview-06-17")

    Returns:
        dict: {
            "phase_summary": model for generating narrative phase summaries,
            "formatter": model for formatting natural language to JSON
        }

    Examples:
        # Use Claude for phase summaries
        export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"

        # Use a different Gemini model for formatting
        export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
    """
    return {
        "phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
        "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL,
    }
|
||
|
||
|
||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
    """
    Return a mapping of each orderable location for *power_name* to the list
    of valid orders the engine currently allows there.
    """
    all_possible = game.get_all_possible_orders()
    # Locations with no engine-known orders still appear, mapped to [].
    return {
        loc: all_possible.get(loc, [])
        for loc in game.get_orderable_locations(power_name)
    }
|
||
|
||
|
||
async def get_valid_orders(
    game: Game,
    client,  # BaseModelClient instance
    board_state,
    power_name: str,
    possible_orders: Dict[str, List[str]],
    game_history,
    model_error_stats,
    agent_goals=None,
    agent_relationships=None,
    agent_private_diary_str=None,
    log_file_path: Optional[str] = None,
    phase: Optional[str] = None,
) -> Dict[str, List[str]]:
    """
    Generates orders with the LLM, validates them by round-tripping through the
    engine, and returns **both** the accepted and rejected orders so the caller
    can record invalid attempts.

    Validation works by temporarily setting each candidate order on the game
    (set_orders / get_orders) and checking whether the engine kept it; the
    power's orders are cleared afterwards so the main engine flow is unaffected.

    Returns
    -------
    dict : { "valid": [...], "invalid": [...] }
    """

    # ── 1. Ask the model ───────────────────────────────────────
    raw_orders = await client.get_orders(
        game=game,
        board_state=board_state,
        power_name=power_name,
        possible_orders=possible_orders,
        conversation_text=game_history,
        model_error_stats=model_error_stats,
        agent_goals=agent_goals,
        agent_relationships=agent_relationships,
        agent_private_diary_str=agent_private_diary_str,
        log_file_path=log_file_path,
        phase=phase,
    )

    invalid_info: list[str] = []  # human-readable reasons, one per rejected order
    valid: list[str] = []
    invalid: list[str] = []

    # ── 2. Type check ──────────────────────────────────────────
    if not isinstance(raw_orders, list):
        logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
        # NOTE(review): assumes model_error_stats[model] is a pre-initialized
        # counter dict (e.g. defaultdict) — a plain dict would raise KeyError.
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        return {"valid": client.fallback_orders(possible_orders), "invalid": []}

    # ── 3. Round-trip validation with engine ───────────────────
    # Three-letter power codes → engine power names, for callers that pass codes.
    CODE_TO_ENGINE = {
        "AUT": "AUSTRIA",
        "ENG": "ENGLAND",
        "FRA": "FRANCE",
        "GER": "GERMANY",
        "ITA": "ITALY",
        "RUS": "RUSSIA",
        "TUR": "TURKEY",
    }
    # NOTE(review): raises KeyError if power_name is neither an engine power
    # nor one of the codes above — presumably callers guarantee this.
    engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]

    for move in raw_orders:
        # Skip empty / whitespace-only entries.
        if not move or not move.strip():
            continue

        upper = move.upper()

        # WAIVE is always valid
        if upper == "WAIVE":
            valid.append("WAIVE")
            continue

        # Round-trip: set the single order, then read back what the engine kept.
        game.clear_orders(engine_power)
        game.set_orders(engine_power, [upper])
        normed = game.get_orders(engine_power)

        if normed:  # accepted — keep the engine-normalized form
            valid.append(normed[0])
        else:  # rejected
            invalid.append(upper)
            invalid_info.append(f"Order '{move}' is invalid for {power_name}")

    game.clear_orders(engine_power)  # clean slate for main engine flow

    # ── 4. Legacy logging & stats updates ──────────────────────
    if invalid_info:  # at least one bad move
        logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
        model_error_stats[client.model_name]["order_decoding_errors"] += 1
        logger.debug("[%s] Some orders invalid, using fallback.", power_name)
    else:
        logger.debug("[%s] All orders valid: %s", power_name, valid)

    # ── 5. Fallback when nothing survives ─────────────────────
    if not valid:
        fallback = client.fallback_orders(possible_orders)
        return {"valid": fallback, "invalid": invalid}

    return {"valid": valid, "invalid": invalid}
|
||
|
||
|
||
def normalize_and_compare_orders(
    issued_orders: Dict[str, List[str]],
    accepted_orders_dict: Dict[str, List[str]],
    game: Game,
) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Normalize issued orders and engine-accepted orders to a common form and
    report the differences per power.

    Args:
        issued_orders: Orders issued by each power {power_name: [orders]}.
        accepted_orders_dict: Orders the engine accepted, typically from
            game.get_state()["orders"].
        game: The current Game object (its map performs the normalization).

    Returns:
        Tuple of (orders_not_accepted, orders_not_issued):
        - orders_not_accepted: orders issued but not accepted by the engine.
        - orders_not_issued: orders accepted by the engine but never issued.
        Both are {power_name: set-of-normalized-orders} and omit empty sets.
    """
    game_map = game.map

    def _norm(order: str) -> str:
        # Normalize one order string via the map; fall back to the raw string
        # so a single bad order never aborts the whole comparison.
        if not order:
            return order
        try:
            return game_map.norm(order)
        except Exception as e:
            logger.warning(f"Could not normalize order '{order}': {e}")
            return order

    orders_not_accepted: Dict[str, Set[str]] = {}
    orders_not_issued: Dict[str, Set[str]] = {}

    for pwr in set(issued_orders) | set(accepted_orders_dict):
        issued_set: Set[str] = set()
        try:
            issued_set = {_norm(o) for o in issued_orders.get(pwr, []) if o}
        except Exception as e:
            logger.error(f"Error normalizing issued orders for {pwr}: {e}")

        accepted_set: Set[str] = set()
        try:
            accepted_set = {_norm(o) for o in accepted_orders_dict.get(pwr, []) if o}
        except Exception as e:
            logger.error(f"Error normalizing accepted orders for {pwr}: {e}")

        rejected = issued_set - accepted_set
        unissued = accepted_set - issued_set
        if rejected:
            orders_not_accepted[pwr] = rejected
        if unissued:
            orders_not_issued[pwr] = unissued

    return orders_not_accepted, orders_not_issued
|
||
|
||
|
||
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
||
# Helper to load prompt text from file relative to the expected 'prompts' dir
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
    """
    Return the contents of *filename* while never joining paths twice.

    Resolution rules
    ----------------
    1. If *filename* is absolute → use it directly.
    2. Otherwise → join it onto a base directory: the caller-supplied
       *prompts_dir* if given, else the package's default ``prompts`` dir.
       (This applies whether or not *filename* contains its own path
       components, e.g. 'unformatted/x.txt'.)

    Returns the stripped file contents, or "" (with an error logged) when
    the file does not exist.
    """
    if os.path.isabs(filename):  # rule 1
        prompt_path = filename
    else:  # rule 2 — one base dir, computed once
        base_dir = prompts_dir or os.path.join(os.path.dirname(__file__), "prompts")
        prompt_path = os.path.join(base_dir, filename)

    try:
        with open(prompt_path, "r", encoding="utf-8") as fh:
            return fh.read().strip()
    except FileNotFoundError:
        logger.error(f"Prompt file not found: {prompt_path}")
        return ""
|
||
|
||
|
||
# == New LLM Response Logging Function ==
|
||
# == New LLM Response Logging Function ==
def log_llm_response(
    log_file_path: str,
    model_name: str,
    power_name: Optional[str],  # None for game-level calls such as summaries
    phase: str,
    response_type: str,
    raw_input_prompt: str,
    raw_response: str,
    success: str,  # stringly-typed status, kept as str by design
):
    """Append one raw LLM interaction as a row of a CSV log file.

    Creates the parent directory on demand and writes the header row only
    when the file is new or empty. Failures are logged and swallowed so
    logging can never interrupt the game loop.
    """
    fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
    row = {
        "model": model_name,
        "power": power_name if power_name else "game",
        "phase": phase,
        "response_type": response_type,
        "raw_input": raw_input_prompt,
        "raw_response": raw_response,
        "success": success,
    }
    try:
        # Create the parent directory if the path has one.
        log_dir = os.path.dirname(log_file_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # A header is needed only when the file doesn't exist yet or is empty.
        needs_header = not (os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0)

        with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(
                csvfile,
                fieldnames=fieldnames,
                quoting=csv.QUOTE_ALL,  # prompts routinely contain commas/newlines
                escapechar="\\",
            )
            if needs_header:
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
||
|
||
|
||
async def run_llm_and_log(
    client: "BaseModelClient",
    prompt: str,
    log_file_path: str,  # Kept for context, but not used for logging here
    power_name: Optional[str],  # Kept for context, but not used for logging here
    phase: str,  # Kept for context, but not used for logging here
    response_type: str,  # Kept for context, but not used for logging here
    temperature: float = 0.0,
    *,
    attempts: int = 5,
    backoff_base: float = 1.0,
    backoff_factor: float = 2.0,
    jitter: float = 0.3,
) -> str:
    """
    Calls `client.generate_response` with retry logic and returns the raw output.

    An empty response counts as a failure and is retried. Between attempts an
    exponential back-off with jitter is applied; after the FINAL failed attempt
    we log once and return immediately — no pointless trailing sleep (the old
    implementation slept even after the last attempt).

    Returns the raw response string, or "" if every attempt failed.
    """
    raw_response = ""  # Initialize in case of error

    for attempt in range(attempts):
        try:
            raw_response = await client.generate_response(prompt, temperature=temperature)
            if not raw_response:
                raise Exception("Empty response from client")
            return raw_response
        except Exception as e:
            if attempt == attempts - 1:
                # Final attempt: log the failure and fall through without sleeping.
                logger.error(
                    f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
                    exc_info=True,
                )
                break
            # Back-off before the next attempt.
            delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
            await asyncio.sleep(delay)

    return raw_response
|
||
|
||
|
||
# Injects a few lines of random alphanumeric characters into the system
# prompt so that trials differ even at temp=0. Temp=0 is important for
# better performance on deciding moves, and to ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
    """Return a tagged block of *n_lines* random alphanumeric lines."""
    alphabet = string.ascii_letters + string.digits
    noise_lines = []
    for _ in range(n_lines):
        noise_lines.append("".join(random.choices(alphabet, k=n_chars_per_line)))
    return "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(noise_lines) + "\n</RANDOM SEED>"
|
||
|
||
|
||
def get_prompt_path(prompt_name: str) -> str:
    """Resolve a prompt filename against the USE_UNFORMATTED_PROMPTS setting.

    Args:
        prompt_name: Base name of the prompt file
            (e.g., "conversation_instructions.txt")

    Returns:
        str: "unformatted/{prompt_name}" when unformatted prompts are enabled,
        otherwise the name unchanged.
    """
    prefix = "unformatted/" if config.USE_UNFORMATTED_PROMPTS else ""
    return f"{prefix}{prompt_name}"
|
||
|
||
|
||
def normalize_recipient_name(recipient: str) -> str:
    """Normalize recipient names to handle LLM typos and abbreviations.

    The input is upper-cased and stripped before lookup, so case variants
    ("Germany", "germany") collapse automatically; unknown names are returned
    unchanged (post-normalization).
    """
    if not recipient:
        return recipient

    recipient = recipient.upper().strip()

    # Common LLM typos and abbreviations found in data. Lookups happen on the
    # upper-cased name, so Title-case keys and canonical-name identity entries
    # (previously present) were unreachable/redundant — .get()'s default
    # already passes canonical names through.
    name_mapping = {
        "EGMANY": "GERMANY",
        "GERMAN": "GERMANY",
        "UK": "ENGLAND",
        "BRIT": "ENGLAND",
    }

    return name_mapping.get(recipient, recipient)
|