AI_Diplomacy/ai_diplomacy/game_logic.py

386 lines
16 KiB
Python

# ai_diplomacy/game_logic.py
import logging
import os
import json
import asyncio
from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from pathlib import Path
from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
from .agent import DiplomacyAgent, ALL_POWERS
from .clients import load_model_client
from .game_history import GameHistory
from .initialization import initialize_agent_state_ext
from .utils import atomic_write_json, assign_models_to_powers
logger = logging.getLogger(__name__)
# --- Serialization / Deserialization ---
def serialize_agent(agent: DiplomacyAgent) -> dict:
    """Build a plain-dict snapshot of *agent* for inclusion in a JSON save.

    The model name and token limit are read from ``agent.client``; the goals,
    relationships and both diary attributes are stored as-is (the same objects
    the agent holds, not copies).
    """
    snapshot = {"power_name": agent.power_name}

    # Client-level settings needed to re-create the model client on resume.
    llm_client = agent.client
    snapshot["model_id"] = llm_client.model_name
    snapshot["max_tokens"] = llm_client.max_tokens

    # Agent-level state, copied over attribute by attribute.
    for field_name in ("goals", "relationships", "full_private_diary", "private_diary"):
        snapshot[field_name] = getattr(agent, field_name)

    return snapshot
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
    """
    Rebuild a ``DiplomacyAgent`` from a dictionary snapshot.

    Args:
        agent_data: Snapshot dictionary. ``"power_name"`` is required;
            ``"model_id"`` is required unless *override_model_id* is given.
            ``"max_tokens"``, ``"goals"``, ``"relationships"``,
            ``"full_private_diary"`` and ``"private_diary"`` are optional.
        prompts_dir: Passed both to the model-client loader and to the agent.
        override_model_id: When truthy (e.g. because ``--models`` was supplied
            on the command line while resuming), this model is loaded instead
            of the one recorded in *agent_data*.

    Returns:
        The reconstructed agent with its diary attributes restored.
    """
    chosen_model = override_model_id if override_model_id else agent_data["model_id"]

    llm_client = load_model_client(chosen_model, prompts_dir=prompts_dir)
    # Re-apply the saved token limit; snapshots without one get 16000.
    llm_client.max_tokens = agent_data.get("max_tokens", 16000)

    restored = DiplomacyAgent(
        power_name=agent_data["power_name"],
        client=llm_client,
        initial_goals=agent_data.get("goals", []),
        initial_relationships=agent_data.get("relationships", None),
        prompts_dir=prompts_dir,
    )

    # Diary attributes are assigned after construction; a missing entry becomes [].
    for diary_attr in ("full_private_diary", "private_diary"):
        setattr(restored, diary_attr, agent_data.get(diary_attr, []))

    return restored
# --- State Management ---
# game_logic.py
_PHASE_ORDER = ["M", "R", "A"]  # Movement → Retreats → Adjustments


def _next_phase_name(short: str) -> str:
    """
    Compute the short phase name that this module treats as following *short*.

    *short* is read as ``<season letter><4-digit year>...<phase type>``: the
    first character is the season, characters 1-4 the year, and the last
    character one of the entries of ``_PHASE_ORDER``.

    As implemented:
        S1901M → S1901R, S1901R → S1901A   (same season, next phase type)
        S1901A → W1901M                    ('S' adjustments → 'W', same year)
        W1901A → S1902M                    (any other season → 'S', next year)

    NOTE(review): any season letter other than 'S' (so 'F' as well as 'W')
    takes the "next year, spring" branch after an 'A' phase, and a retreat
    phase is always followed by an adjustment phase of the same season.
    Confirm this matches the phase sequence the callers expect.

    Raises:
        ValueError: if the year is not numeric or the phase type is not in
            ``_PHASE_ORDER``.
        IndexError: if *short* is empty.
    """
    season_letter, phase_type = short[0], short[-1]
    year = int(short[1:5])

    position = _PHASE_ORDER.index(phase_type)
    has_later_type = position + 1 < len(_PHASE_ORDER)
    if has_later_type:
        # Not yet at the last phase type: advance within the same season.
        return f"{season_letter}{year}{_PHASE_ORDER[position + 1]}"

    # The phase type was the last one ('A'), so move on to a movement phase.
    if season_letter != "S":
        return f"S{year + 1}M"
    return f"W{year}M"
def _phase_year(phase_name: str) -> Optional[int]:
    """Return the year held in characters 1-4 of *phase_name*, or ``None`` if they are not numeric."""
    try:
        return int(phase_name[1:5])
    except ValueError:
        return None


def save_game_state(
    game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
):
    """
    Serialise the entire game to JSON, preserving per-phase custom metadata
    (e.g. 'state_agents') that may have been written by earlier save passes.

    Args:
        game: The game to export via ``to_saved_game_format``.
        agents: Live agents keyed by power name. Agents of non-eliminated
            powers are snapshotted into the completed phase; every agent's
            relationships and goals go into ``final_agent_states``.
        game_history: Accepted for interface compatibility; not read here.
        output_path: Destination JSON file. If it already exists it is read
            first so that non-standard per-phase keys can be carried over.
        run_config: Run configuration. ``max_year`` bounds which phases are
            enriched, and the whole namespace is stored as the completed
            phase's ``config``.
        completed_phase_name: Name of the phase that receives the fresh
            ``config`` and ``state_agents`` blocks.

    Phases after ``run_config.max_year`` are still written to the file, but
    the merge loop stops at the first such phase, so they are not enriched.
    Phase names without a numeric year (for example a terminal "COMPLETED"
    entry) no longer abort the save: they are simply exempt from the year
    cut-off.
    """
    logger.info(f"Saving game state to {output_path}")

    # ------------------------------------------------------------------ #
    # 1. If the file already exists, cache the per-phase custom blocks.  #
    # ------------------------------------------------------------------ #
    # Keys that the save pass regenerates itself; everything else found in a
    # previous save is treated as custom metadata and carried over.
    standard_keys = {"name", "orders", "results", "messages", "state", "config"}
    previous_phase_extras: Dict[str, Dict[str, Any]] = {}
    if os.path.isfile(output_path):
        try:
            with open(output_path, "r", encoding="utf-8") as fh:
                previous_save = json.load(fh)
            for phase in previous_save.get("phases", []):
                extras = {k: v for k, v in phase.items() if k not in standard_keys}
                if extras:
                    previous_phase_extras[phase["name"]] = extras
        except Exception as exc:
            # Best effort: an unreadable previous save must not block the new one.
            logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)

    # -------------------------------------------------------------- #
    # 2. Build the fresh base structure from the diplomacy library.  #
    # -------------------------------------------------------------- #
    saved_game = to_saved_game_format(game)

    # -------------------------------------------------------------- #
    # 3. Walk every phase and merge the metadata back in.            #
    # -------------------------------------------------------------- #
    # Capture the *current* snapshot of every live agent exactly once.
    current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}

    for phase_block in saved_game.get("phases", []):
        phase_name = phase_block["name"]

        # Stop at the first phase beyond the configured horizon. Names with
        # no parseable year are never treated as "beyond" it.
        phase_year = _phase_year(phase_name)
        if phase_year is not None and phase_year > run_config.max_year:
            break

        # 3a. Re-attach anything we cached from a previous save.
        if phase_name in previous_phase_extras:
            phase_block.update(previous_phase_extras[phase_name])

        # 3b. For the completed phase, also inject the run configuration and
        #     the fresh agent snapshot.
        if phase_name == completed_phase_name:
            # Make run_config serialisable: Path values become strings.
            cfg = vars(run_config).copy()
            if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
                cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
            if isinstance(cfg.get("prompts_dir"), Path):
                cfg["prompts_dir"] = str(cfg["prompts_dir"])

            phase_block["config"] = cfg
            phase_block["state_agents"] = current_state_agents

    # -------------------------------------------------------------- #
    # 4. Attach top-level metadata and write atomically.             #
    # -------------------------------------------------------------- #
    saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
    saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}

    atomic_write_json(saved_game, output_path)
    logger.info("Game state saved successfully.")
def load_game_state(
    run_dir: str,
    game_file_name: str,
    run_config: Namespace,
    resume_from_phase: Optional[str] = None,
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
    """Load a saved game file and reconstruct the game, agents and history.

    Args:
        run_dir: Directory containing the save file.
        game_file_name: File name of the save inside *run_dir*.
        run_config: Run configuration. ``models`` (comma-separated, 1 or one
            per power) overrides the saved model ids; ``prompts_dir_map`` /
            ``prompts_dir`` select the prompts directory for each power.
        resume_from_phase: If given, every phase after this one is dropped
            and the simulation restarts at this phase.

    Returns:
        ``(game, agents, game_history, run_config)``. When no phases remain
        (fresh start), ``agents`` is empty and the history is new.

    Raises:
        FileNotFoundError: the save file does not exist.
        ValueError: *resume_from_phase* is not in the save (and is not
            "S1901M"), ``--models`` has an invalid number of entries, or the
            last completed phase carries no ``state_agents`` block.
    """
    game_file_path = os.path.join(run_dir, game_file_name)
    if not os.path.exists(game_file_path):
        raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")

    logger.info(f"Loading game state from: {game_file_path}")
    # Read with the same encoding save_game_state uses when re-reading the file.
    with open(game_file_path, "r", encoding="utf-8") as f:
        saved_game_data = json.load(f)

    # If resuming, find the specified phase and truncate the data after it.
    if resume_from_phase:
        logger.info(f"Resuming from phase '{resume_from_phase}'. Truncating subsequent data.")
        saved_phases = saved_game_data["phases"]
        resume_idx = next(
            (i for i, phase in enumerate(saved_phases) if phase["name"] == resume_from_phase),
            None,
        )
        if resume_idx is not None:
            # Keep everything up to and including the resume phase; that phase
            # is the one that will be run next.
            saved_game_data["phases"] = saved_phases[: resume_idx + 1]
            logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
        elif resume_from_phase == "S1901M":
            # The very first phase is allowed to be absent: start from scratch.
            saved_game_data["phases"] = []
            logger.info("Resuming from S1901M. Starting with a clean history.")
        else:
            raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")

    # Wipe the data of the phase about to be (re)played, but preserve the keys.
    # Guarded so that an empty phase list (fresh start) does not raise IndexError.
    if saved_game_data["phases"]:
        last_phase = saved_game_data["phases"][-1]
        last_phase["orders"] = {}
        last_phase["results"] = {}
        last_phase["messages"] = []

    # Reconstruct the Game object.
    game = from_saved_game_format(saved_game_data)
    game.phase_summaries = saved_game_data.get("phase_summaries", {})

    if not saved_game_data["phases"]:
        # Resuming from the very beginning: agents are created by the main loop.
        logger.info("No previous phases found. Initializing fresh agents and history.")
        return game, {}, GameHistory(), run_config

    # The save holds the game up to and including the current (uncompleted)
    # phase, so agent state comes from the previous (completed) phase.
    if len(saved_game_data["phases"]) <= 1:
        last_phase_data = {}
    else:
        last_phase_data = saved_game_data["phases"][-2]

    # -------------------- Rebuild agents -------------------- #
    agents: Dict[str, DiplomacyAgent] = {}

    # Build a power→model map from the CLI argument --models, if present.
    power_model_map: Dict[str, str] = {}
    if run_config and getattr(run_config, "models", None):
        provided = [m.strip() for m in run_config.models.split(",")]
        powers_order = sorted(ALL_POWERS)
        if len(provided) == len(powers_order):
            power_model_map = dict(zip(powers_order, provided))
        elif len(provided) == 1:
            power_model_map = dict.fromkeys(powers_order, provided[0])
        else:
            raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")

    if "state_agents" not in last_phase_data:
        raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")

    # Per-power prompts directory, falling back to the single shared directory
    # for powers missing from the map (same rule as initialize_new_game).
    prompts_dir_map = getattr(run_config, "prompts_dir_map", None)
    default_prompts_dir = getattr(run_config, "prompts_dir", None)

    logger.info("Rebuilding agents from saved state...")
    for power_name, agent_data in last_phase_data["state_agents"].items():
        if prompts_dir_map:
            prompts_dir_for_power = prompts_dir_map.get(power_name, default_prompts_dir)
        else:
            prompts_dir_for_power = default_prompts_dir
        agents[power_name] = deserialize_agent(
            agent_data,
            prompts_dir=prompts_dir_for_power,
            override_model_id=power_model_map.get(power_name),
        )
    logger.info(f"Rebuilt {len(agents)} agents.")

    # -------------------- Rebuild GameHistory -------------------- #
    game_history = GameHistory()
    logger.info("Rebuilding game history...")
    # The last phase is the one about to be played, so it is not replayed here.
    for phase_data in saved_game_data["phases"][:-1]:
        phase_name = phase_data["name"]
        game_history.add_phase(phase_name)
        for msg in phase_data.get("messages", []):
            game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
        if "state_history_plans" in phase_data:
            for p_name, plan in phase_data["state_history_plans"].items():
                game_history.add_plan(phase_name, p_name, plan)
    logger.info("Game history rebuilt.")

    return game, agents, game_history, run_config
# ai_diplomacy/game_logic.py
async def initialize_new_game(
    args: Namespace,
    game: Game,
    game_history: GameHistory,
    llm_log_file_path: str,
) -> Dict[str, DiplomacyAgent]:
    """Initializes agents for a new game (supports per-power prompt directories).

    Args:
        args: Parsed CLI arguments. Reads ``max_tokens``,
            ``max_tokens_per_model``, ``models``, ``prompts_dir`` and, if
            present, ``prompts_dir_map``.
        game: The game being started. Its ``power_model_map`` attribute is
            set here, and eliminated powers are skipped.
        game_history: Passed through to each agent's state initialization.
        llm_log_file_path: Passed through to each agent's state initialization.

    Returns:
        Agents keyed by power name. An agent whose client or constructor
        failed is omitted; an agent whose asynchronous state initialization
        failed is still returned (the failure is only logged).

    Raises:
        ValueError: ``--models`` lists neither one model nor one per power.
    """
    powers_order = sorted(ALL_POWERS)
    n_powers = len(powers_order)

    # Token limits: start from the global default, optionally override per power.
    model_max_tokens = {p: args.max_tokens for p in powers_order}
    if args.max_tokens_per_model:
        per_model_values = [s.strip() for s in args.max_tokens_per_model.split(",")]
        if len(per_model_values) == n_powers:
            for power, token_val_str in zip(powers_order, per_model_values):
                model_max_tokens[power] = int(token_val_str)
        else:
            logger.warning(f"Expected {n_powers} values for --max_tokens_per_model, using default.")

    # Power→model mapping: one model per power (alphabetical power order),
    # a single model for everyone, or the project default assignment.
    if args.models:
        provided_models = [name.strip() for name in args.models.split(",")]
        if len(provided_models) == n_powers:
            game.power_model_map = dict(zip(powers_order, provided_models))
        elif len(provided_models) == 1:
            game.power_model_map = dict.fromkeys(powers_order, provided_models[0])
        else:
            logger.error(f"Expected {n_powers} models for --models but got {len(provided_models)}.")
            raise ValueError(f"Invalid number of models. Models list must be either exactly 1 or {n_powers} models, comma delimited.")
    else:
        game.power_model_map = assign_models_to_powers()

    agents: Dict[str, DiplomacyAgent] = {}
    initialization_tasks = []
    # Power name for each queued task, kept in lock-step with
    # initialization_tasks so results can be attributed reliably.
    task_powers = []
    prompts_dir_map = getattr(args, "prompts_dir_map", None)

    logger.info("Initializing Diplomacy Agents for each power...")
    for power_name, model_id in game.power_model_map.items():
        if game.powers[power_name].is_eliminated():
            continue

        # Determine the prompts directory for this power.
        if prompts_dir_map:
            prompts_dir_for_power = prompts_dir_map.get(power_name, args.prompts_dir)
        else:
            prompts_dir_for_power = args.prompts_dir

        try:
            client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
            client.max_tokens = model_max_tokens[power_name]
            agent = DiplomacyAgent(
                power_name=power_name,
                client=client,
                prompts_dir=prompts_dir_for_power,
            )
            agents[power_name] = agent
            logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
            init_task = initialize_agent_state_ext(
                agent,
                game,
                game_history,
                llm_log_file_path,
                prompts_dir=prompts_dir_for_power,
            )
            initialization_tasks.append(init_task)
            task_powers.append(power_name)
        except Exception as e:
            # One broken power must not prevent the others from starting.
            logger.error(
                f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
                exc_info=True,
            )

    logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
    initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)

    for power_name, result in zip(task_powers, initialization_results):
        # With return_exceptions=True a result may be any BaseException
        # (including a cancellation), so test against the base class.
        if isinstance(result, BaseException):
            logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
        else:
            logger.info(f"Successfully initialized agent state for {power_name}.")

    return agents