mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
573 lines
25 KiB
Python
573 lines
25 KiB
Python
import argparse
|
|
import logging
|
|
import time
|
|
import dotenv
|
|
import os
|
|
import json
|
|
import asyncio
|
|
from collections import defaultdict
|
|
import concurrent.futures
|
|
|
|
# Suppress Gemini/PaLM gRPC warnings
|
|
os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40" # ERROR level only
|
|
|
|
from diplomacy import Game
|
|
from diplomacy.engine.message import GLOBAL, Message
|
|
from diplomacy.utils.export import to_saved_game_format
|
|
|
|
from ai_diplomacy.clients import load_model_client
|
|
from ai_diplomacy.utils import (
|
|
get_valid_orders,
|
|
gather_possible_orders,
|
|
assign_models_to_powers,
|
|
)
|
|
from ai_diplomacy.negotiations import conduct_negotiations
|
|
from ai_diplomacy.planning import planning_phase
|
|
from ai_diplomacy.game_history import GameHistory
|
|
from ai_diplomacy.agent import DiplomacyAgent
|
|
import ai_diplomacy.narrative
|
|
|
|
dotenv.load_dotenv()
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
|
|
datefmt="%H:%M:%S",
|
|
)
|
|
# Silence noisy dependencies
|
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
|
logging.getLogger("root").setLevel(logging.WARNING) # Assuming root handles AFC
|
|
|
|
|
|
def parse_arguments():
|
|
parser = argparse.ArgumentParser(
|
|
description="Run a Diplomacy game simulation with configurable parameters."
|
|
)
|
|
parser.add_argument(
|
|
"--max_year",
|
|
type=int,
|
|
default=1901,
|
|
help="Maximum year to simulate. The game will stop once this year is reached.",
|
|
)
|
|
parser.add_argument(
|
|
"--num_negotiation_rounds",
|
|
type=int,
|
|
default=0,
|
|
help="Number of negotiation rounds per phase.",
|
|
)
|
|
parser.add_argument(
|
|
"--output",
|
|
type=str,
|
|
default="",
|
|
help="Output filename for the final JSON result. If not provided, a timestamped name will be generated.",
|
|
)
|
|
parser.add_argument(
|
|
"--models",
|
|
type=str,
|
|
default="",
|
|
help=(
|
|
"Comma-separated list of model names to assign to powers in order. "
|
|
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--planning_phase",
|
|
action="store_true",
|
|
help="Enable the planning phase for each power to set strategic directives.",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
async def main():
|
|
args = parse_arguments()
|
|
max_year = args.max_year
|
|
|
|
logger.info(
|
|
"Starting a new Diplomacy game for testing with multiple LLMs, now async!"
|
|
)
|
|
start_whole = time.time()
|
|
|
|
model_error_stats = defaultdict(
|
|
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
|
|
)
|
|
|
|
# Create a fresh Diplomacy game
|
|
game = Game()
|
|
game_history = GameHistory()
|
|
|
|
# Ensure game has phase_summaries attribute
|
|
if not hasattr(game, "phase_summaries"):
|
|
game.phase_summaries = {}
|
|
|
|
# Determine the result folder based on a timestamp
|
|
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
|
result_folder = f"./results/{timestamp_str}"
|
|
os.makedirs(result_folder, exist_ok=True)
|
|
|
|
# File paths
|
|
manifesto_path = f"{result_folder}/game_manifesto.txt"
|
|
# Use provided output filename or generate one based on the timestamp
|
|
game_file_path = args.output if args.output else f"{result_folder}/lmvsgame.json"
|
|
overview_file_path = f"{result_folder}/overview.jsonl"
|
|
|
|
# Handle power model mapping
|
|
if args.models:
|
|
# Expected order: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY
|
|
powers_order = [
|
|
"AUSTRIA",
|
|
"ENGLAND",
|
|
"FRANCE",
|
|
"GERMANY",
|
|
"ITALY",
|
|
"RUSSIA",
|
|
"TURKEY",
|
|
]
|
|
provided_models = [name.strip() for name in args.models.split(",")]
|
|
if len(provided_models) != len(powers_order):
|
|
logger.error(
|
|
f"Expected {len(powers_order)} models for --power-models but got {len(provided_models)}. Exiting."
|
|
)
|
|
return
|
|
game.power_model_map = dict(zip(powers_order, provided_models))
|
|
else:
|
|
game.power_model_map = assign_models_to_powers()
|
|
|
|
# == Goal 1: Centralize Agent Instances ==
|
|
agents = {}
|
|
initialization_tasks = []
|
|
logger.info("Initializing Diplomacy Agents for each power...")
|
|
for power_name, model_id in game.power_model_map.items():
|
|
if not game.powers[power_name].is_eliminated(): # Only create for active powers initially
|
|
try:
|
|
client = load_model_client(model_id)
|
|
# TODO: Potentially load initial goals/relationships from config later
|
|
agent = DiplomacyAgent(power_name=power_name, client=client)
|
|
agents[power_name] = agent
|
|
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
|
initialization_tasks.append(agent.initialize_agent_state(game, game_history))
|
|
except Exception as e:
|
|
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
|
else:
|
|
logger.info(f"Skipping agent initialization for eliminated power: {power_name}")
|
|
|
|
# == Run initializations concurrently ==
|
|
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
|
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
|
# Check results for errors
|
|
# Note: agents dict might have fewer entries than results if client creation failed
|
|
initialized_powers = list(agents.keys()) # Get powers for which agents were created
|
|
for i, result in enumerate(initialization_results):
|
|
if i < len(initialized_powers): # Ensure index is valid for initialized_powers
|
|
power_name = initialized_powers[i]
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
|
# Potentially remove agent if initialization failed? Depends on desired behavior.
|
|
else:
|
|
logger.info(f"Successfully initialized agent state for {power_name}.")
|
|
else:
|
|
logger.error(f"Initialization result mismatch - unexpected result: {result}")
|
|
# ========================================
|
|
|
|
while not game.is_game_done:
|
|
phase_start = time.time()
|
|
current_phase = game.get_current_phase()
|
|
|
|
# Ensure the current phase is registered in the history
|
|
game_history.add_phase(current_phase)
|
|
|
|
# Store the current phase's short name once for consistent use
|
|
current_short_phase = game.current_short_phase
|
|
|
|
logger.info(
|
|
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
|
|
)
|
|
|
|
# DEBUG: Print the short phase to confirm
|
|
logger.debug(f"DEBUG: current_short_phase is '{current_short_phase}'")
|
|
|
|
# Prevent unbounded simulation based on year
|
|
year_str = current_phase[1:5]
|
|
year_int = int(year_str)
|
|
if year_int > max_year:
|
|
logger.info(f"Reached year {year_int}, stopping the test game early.")
|
|
break
|
|
|
|
# If it's a movement phase (e.g. ends with "M"), conduct negotiations
|
|
if game.current_short_phase.endswith("M"):
|
|
|
|
if args.planning_phase:
|
|
logger.debug("Starting planning phase block...")
|
|
game_history = planning_phase(
|
|
game,
|
|
agents, # Pass agents dict
|
|
game_history,
|
|
model_error_stats,
|
|
)
|
|
logger.debug("Starting negotiation phase block...")
|
|
# Await the async negotiation function
|
|
game_history = await conduct_negotiations(
|
|
game,
|
|
agents, # Pass agents dict
|
|
game_history,
|
|
model_error_stats,
|
|
max_rounds=args.num_negotiation_rounds,
|
|
)
|
|
|
|
# --- Async Order Generation ---
|
|
active_powers = [
|
|
(p_name, p_obj)
|
|
for p_name, p_obj in game.powers.items()
|
|
if not p_obj.is_eliminated()
|
|
]
|
|
|
|
order_tasks = []
|
|
power_names_for_order_tasks = []
|
|
|
|
for power_name, _ in active_powers:
|
|
if power_name not in agents:
|
|
logger.warning(f"Agent not found for active power {power_name}. Skipping order generation.")
|
|
continue
|
|
agent = agents[power_name]
|
|
client = agent.client
|
|
|
|
possible_orders = gather_possible_orders(game, power_name)
|
|
if not possible_orders:
|
|
logger.debug(f"No orderable locations for {power_name}; skipping order generation.")
|
|
continue
|
|
board_state = game.get_state()
|
|
|
|
# Append the awaitable call to get_valid_orders
|
|
order_tasks.append(
|
|
get_valid_orders(
|
|
game,
|
|
client,
|
|
board_state,
|
|
power_name,
|
|
possible_orders,
|
|
game_history, # Pass game_history object
|
|
model_error_stats,
|
|
agent_goals=agent.goals,
|
|
agent_relationships=agent.relationships,
|
|
)
|
|
)
|
|
power_names_for_order_tasks.append(power_name)
|
|
logger.debug(f"Prepared get_valid_orders task for {power_name}.")
|
|
|
|
# Run order generation tasks concurrently
|
|
if order_tasks:
|
|
logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...")
|
|
order_results = await asyncio.gather(*order_tasks, return_exceptions=True)
|
|
else:
|
|
logger.debug("No order generation tasks to run.")
|
|
order_results = []
|
|
|
|
# Process order results and set them in the game
|
|
for i, result in enumerate(order_results):
|
|
p_name = power_names_for_order_tasks[i]
|
|
agent = agents[p_name] # Get agent for logging/stats if needed
|
|
model_name = agent.client.model_name
|
|
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Error during get_valid_orders for {p_name}: {result}", exc_info=result)
|
|
# Log error stats (consider if fallback orders should be set here)
|
|
if model_name in model_error_stats:
|
|
model_error_stats[model_name].setdefault("order_generation_errors", 0)
|
|
model_error_stats[model_name]["order_generation_errors"] += 1
|
|
# Optionally set fallback orders here if needed, e.g., game.set_orders(p_name, []) or specific fallback
|
|
game.set_orders(p_name, []) # Set empty orders on error for now
|
|
logger.warning(f"Setting empty orders for {p_name} due to generation error.")
|
|
elif result is None:
|
|
# Handle case where get_valid_orders might theoretically return None
|
|
logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.")
|
|
game.set_orders(p_name, [])
|
|
if model_name in model_error_stats:
|
|
model_error_stats[model_name].setdefault("order_generation_errors", 0)
|
|
model_error_stats[model_name]["order_generation_errors"] += 1
|
|
else:
|
|
# Result is the list of validated orders
|
|
orders = result
|
|
logger.debug(f"Validated orders for {p_name}: {orders}")
|
|
if orders:
|
|
game.set_orders(p_name, orders)
|
|
logger.debug(
|
|
f"Set orders for {p_name} in {game.current_short_phase}: {orders}"
|
|
)
|
|
else:
|
|
logger.debug(f"No valid orders returned by get_valid_orders for {p_name}. Setting empty orders.")
|
|
game.set_orders(p_name, []) # Set empty if get_valid_orders returned empty
|
|
|
|
# --- End Async Order Generation ---
|
|
|
|
# Process orders
|
|
logger.info(f"Processing orders for {current_phase}...")
|
|
|
|
# Process with a custom summary callback that captures our custom game_history data
|
|
def phase_summary_callback(system_prompt, user_prompt):
|
|
# This will be called by the game engine's _generate_phase_summary method
|
|
# Get messages for this phase from game_history
|
|
current_phase_obj = None
|
|
for phase in game_history.phases:
|
|
if phase.name == current_short_phase:
|
|
current_phase_obj = phase
|
|
break
|
|
|
|
if not current_phase_obj:
|
|
return f"Phase {current_short_phase} Summary: (No game history data available)"
|
|
|
|
# 1) Gather the current board state, sorted by # of centers
|
|
power_info = []
|
|
for power_name, power in game.powers.items():
|
|
units_list = list(power.units)
|
|
centers_list = list(power.centers)
|
|
power_info.append(
|
|
(power_name, len(centers_list), units_list, centers_list)
|
|
)
|
|
# Sort by descending # of centers
|
|
power_info.sort(key=lambda x: x[1], reverse=True)
|
|
|
|
# 2) Build text lines for the top "Board State Overview"
|
|
top_lines = ["Current Board State (Ordered by SC Count):"]
|
|
for (p_name, sc_count, units, centers) in power_info:
|
|
top_lines.append(
|
|
f" • {p_name}: {sc_count} centers (needs 18 to win). "
|
|
f"Units={units} Centers={centers}"
|
|
)
|
|
|
|
# 3) Map orders to "successful", "failed", or "other" outcomes
|
|
success_dict = {}
|
|
fail_dict = {}
|
|
other_dict = {}
|
|
|
|
orders_dict = game.order_history.get(current_short_phase, {})
|
|
results_for_phase = game.result_history.get(current_short_phase, {})
|
|
|
|
for pwr, pwr_orders in orders_dict.items():
|
|
for order_str in pwr_orders:
|
|
# Extract the unit from the string
|
|
tokens = order_str.split()
|
|
if len(tokens) < 3:
|
|
# Something malformed
|
|
other_dict.setdefault(pwr, []).append(order_str)
|
|
continue
|
|
unit_name = " ".join(tokens[:2])
|
|
# We retrieve the order results for that unit
|
|
results_list = results_for_phase.get(unit_name, [])
|
|
# Check if the results contain e.g. "dislodged", "bounce", "void"
|
|
# We consider success if the result list is empty or has no negative results
|
|
if not results_list or all(res not in ["bounce", "void", "no convoy", "cut", "dislodged", "disrupted"] for res in results_list):
|
|
success_dict.setdefault(pwr, []).append(order_str)
|
|
elif any(res in ["bounce", "void", "no convoy", "cut", "dislodged", "disrupted"] for res in results_list):
|
|
fail_dict.setdefault(pwr, []).append(f"{order_str} - {', '.join(str(r) for r in results_list)}")
|
|
else:
|
|
other_dict.setdefault(pwr, []).append(order_str)
|
|
|
|
# 4) Build textual lists of successful, failed, and "other" moves
|
|
def format_moves_dict(title, moves_dict):
|
|
lines = [title]
|
|
if not moves_dict:
|
|
lines.append(" None.")
|
|
return "\n".join(lines)
|
|
for pwr in sorted(moves_dict.keys()):
|
|
lines.append(f" {pwr}:")
|
|
for mv in moves_dict[pwr]:
|
|
lines.append(f" {mv}")
|
|
return "\n".join(lines)
|
|
|
|
success_section = format_moves_dict("Successful Moves:", success_dict)
|
|
fail_section = format_moves_dict("Unsuccessful Moves:", fail_dict)
|
|
other_section = format_moves_dict("Other / Unclassified Moves:", other_dict)
|
|
|
|
# 5) Combine everything into the final summary text
|
|
summary_parts = []
|
|
summary_parts.append("\n".join(top_lines))
|
|
summary_parts.append("\n" + success_section)
|
|
summary_parts.append("\n" + fail_section)
|
|
|
|
# Only include "Other" section if it has content
|
|
if other_dict:
|
|
summary_parts.append("\n" + other_section)
|
|
|
|
return f"Phase {current_short_phase} Summary:\n\n" + "\n".join(summary_parts)
|
|
|
|
# Process with our custom callback
|
|
game.process(phase_summary_callback=phase_summary_callback)
|
|
|
|
# Log the results
|
|
logger.info(f"Results for {current_phase}:")
|
|
for power_name, power in game.powers.items():
|
|
logger.info(f"{power_name}: {power.centers}")
|
|
|
|
# Ensure messages from game_history are added to the game's message system
|
|
# This is required for messages to appear in the Messages tab
|
|
for phase in game_history.phases:
|
|
if phase.name == current_short_phase:
|
|
for msg in phase.messages:
|
|
try:
|
|
# Only add if not already present (avoid duplicates)
|
|
if not any(m.sender == msg.sender and
|
|
m.recipient == msg.recipient and
|
|
m.message == msg.content
|
|
for m in game.messages.values()):
|
|
game.add_message(Message(
|
|
phase=current_short_phase,
|
|
sender=msg.sender,
|
|
recipient=msg.recipient,
|
|
message=msg.content,
|
|
time_sent=int(time.time())
|
|
))
|
|
except Exception as e:
|
|
logger.warning(f"Could not add message to game: {e}")
|
|
|
|
# Add orders to game history
|
|
for power_name in game.order_history[current_short_phase]:
|
|
orders = game.order_history[current_short_phase][power_name]
|
|
results = []
|
|
for order in orders:
|
|
# Example move: "A PAR H" -> unit="A PAR", order_part="H"
|
|
tokens = order.split(" ", 2)
|
|
if len(tokens) < 3:
|
|
continue
|
|
unit = " ".join(tokens[:2]) # e.g. "A PAR"
|
|
order_part = tokens[2] # e.g. "H" or "S A MAR"
|
|
results.append(
|
|
[str(x) for x in game.result_history[current_short_phase][unit]]
|
|
)
|
|
game_history.add_orders(
|
|
current_short_phase,
|
|
power_name,
|
|
game.order_history[current_short_phase][power_name],
|
|
)
|
|
|
|
logger.info(f"--- Orders Submitted for {current_phase} ---")
|
|
for power, orders in game.order_history.get(current_short_phase, {}).items():
|
|
order_str = ", ".join(orders) if orders else "(No orders/NOP)"
|
|
logger.info(f" {power:<8}: {order_str}")
|
|
logger.info("-----------------------------------")
|
|
|
|
# --- Async State Update ---
|
|
completed_phase_name = current_phase # The phase name BEFORE processing
|
|
logger.info(f"Starting state update analysis for completed phase {completed_phase_name}...")
|
|
|
|
# Ensure phase summary exists (it should be generated by the callback during game.process)
|
|
phase_summary = game.phase_summaries.get(completed_phase_name, f"Summary for {completed_phase_name} not found.")
|
|
if f"Summary for {completed_phase_name} not found" in phase_summary:
|
|
logger.warning(phase_summary)
|
|
|
|
current_board_state = game.get_state() # State *after* processing
|
|
|
|
# Update state concurrently for active agents
|
|
active_agent_powers = [p for p in active_powers if p[0] in agents] # Filter for powers with active agents
|
|
|
|
if active_agent_powers: # Only run if there are agents to update
|
|
logger.info(f"Beginning concurrent state analysis for {len(active_agent_powers)} agents...")
|
|
|
|
analysis_tasks = []
|
|
power_names_for_analysis = []
|
|
|
|
for power_name, _ in active_agent_powers:
|
|
agent = agents[power_name]
|
|
logger.debug(f"Preparing state analysis task for {power_name}")
|
|
# Append the awaitable call
|
|
analysis_tasks.append(
|
|
agent.analyze_phase_and_update_state(
|
|
game,
|
|
current_board_state, # Use state AFTER processing
|
|
phase_summary,
|
|
game_history
|
|
)
|
|
)
|
|
power_names_for_analysis.append(power_name)
|
|
|
|
# Run analysis tasks concurrently
|
|
if analysis_tasks:
|
|
logger.debug(f"Running {len(analysis_tasks)} state analysis tasks concurrently...")
|
|
analysis_results = await asyncio.gather(*analysis_tasks, return_exceptions=True)
|
|
else:
|
|
analysis_results = []
|
|
|
|
# Process results (check for exceptions)
|
|
for i, result in enumerate(analysis_results):
|
|
power_name = power_names_for_analysis[i]
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Error during state analysis for {power_name}: {result}", exc_info=result)
|
|
# Optionally log error stats here
|
|
else:
|
|
# Result is None if the function completes normally
|
|
logger.debug(f"State analysis completed successfully for {power_name}.")
|
|
|
|
logger.info(f"Finished concurrent state analysis for {len(active_agent_powers)} agents.")
|
|
logger.info(f"Completed state update analysis for phase {completed_phase_name}.")
|
|
else:
|
|
logger.info(f"No active agents found to perform state update analysis for phase {completed_phase_name}.")
|
|
# --- End Async State Update ---
|
|
|
|
# Append the strategic directives to the manifesto file
|
|
strategic_directives = game_history.get_strategic_directives()
|
|
if strategic_directives:
|
|
out_str = f"Strategic directives for {current_phase}:\n"
|
|
for power, directive in strategic_directives.items():
|
|
out_str += f"{power}: {directive}\n\n"
|
|
out_str += f"------------------------------------------\n"
|
|
with open(manifesto_path, "a") as f:
|
|
f.write(out_str)
|
|
|
|
# Check if we've exceeded the max year
|
|
year_str = current_phase[1:5]
|
|
year_int = int(year_str)
|
|
if year_int > max_year:
|
|
logger.info(f"Reached year {year_int}, stopping the test game early.")
|
|
break
|
|
|
|
# Game is done
|
|
total_time = time.time() - start_whole
|
|
logger.info(f"Game ended after {total_time:.2f}s. Saving results...")
|
|
|
|
# Now save the game with our added data
|
|
output_path = game_file_path
|
|
# If the file already exists, append a timestamp to the filename
|
|
if os.path.exists(output_path):
|
|
logger.info("Game file already exists, saving with unique filename.")
|
|
timestamp = int(time.time())
|
|
base, ext = os.path.splitext(output_path)
|
|
output_path = f"{base}_{timestamp}{ext}"
|
|
|
|
# Generate the saved game JSON using the standard export function
|
|
saved_game = to_saved_game_format(game)
|
|
|
|
# Verify phase_summaries are available in game.phase_summaries
|
|
logger.info(f"Game has {len(game.phase_summaries)} phase summaries: {list(game.phase_summaries.keys())}")
|
|
|
|
# CRITICAL: Add phase_summaries as a top-level property in the saved game
|
|
# The frontend expects this exact structure
|
|
saved_game['phase_summaries'] = game.phase_summaries
|
|
logger.info(f"Added phase_summaries to saved game with {len(game.phase_summaries)} phases")
|
|
|
|
# Also add summaries to individual phases for backward compatibility
|
|
summary_phases_count = 0
|
|
for i, phase in enumerate(saved_game['phases']):
|
|
phase_name = phase['name']
|
|
if phase_name in game.phase_summaries:
|
|
saved_game['phases'][i]['summary'] = game.phase_summaries[phase_name]
|
|
summary_phases_count += 1
|
|
logger.debug(f"Added summary to phase {phase_name} in export")
|
|
logger.info(f"Added summaries to {summary_phases_count}/{len(saved_game['phases'])} phases in the export")
|
|
|
|
# Save the modified game data
|
|
logger.info(f"Saving game to {output_path}...")
|
|
with open(output_path, "w") as f:
|
|
json.dump(saved_game, f, indent=4)
|
|
|
|
# Dump error stats and power model mapping to the overview file
|
|
with open(overview_file_path, "w") as overview_file:
|
|
overview_file.write(json.dumps(model_error_stats) + "\n")
|
|
overview_file.write(json.dumps(game.power_model_map) + "\n")
|
|
overview_file.write(json.dumps(vars(args)) + "\n")
|
|
|
|
logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
|
|
logger.info("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Run the async main function
|
|
asyncio.run(main())
|