mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
386 lines
16 KiB
Python
386 lines
16 KiB
Python
# ai_diplomacy/game_logic.py
|
|
import logging
|
|
import os
|
|
import json
|
|
import asyncio
|
|
from typing import Dict, Tuple, Optional, Any
|
|
from argparse import Namespace
|
|
from pathlib import Path
|
|
|
|
from diplomacy import Game
|
|
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
|
|
|
|
from .agent import DiplomacyAgent, ALL_POWERS
|
|
from .clients import load_model_client
|
|
from .game_history import GameHistory
|
|
from .initialization import initialize_agent_state_ext
|
|
from .utils import atomic_write_json, assign_models_to_powers
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --- Serialization / Deserialization ---
|
|
|
|
|
|
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|
"""Converts an agent object to a JSON-serializable dictionary."""
|
|
return {
|
|
"power_name": agent.power_name,
|
|
"model_id": agent.client.model_name,
|
|
"max_tokens": agent.client.max_tokens,
|
|
"goals": agent.goals,
|
|
"relationships": agent.relationships,
|
|
"full_private_diary": agent.full_private_diary,
|
|
"private_diary": agent.private_diary,
|
|
}
|
|
|
|
|
|
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
|
"""
|
|
Recreates an agent object from a dictionary.
|
|
|
|
If *override_model_id* is provided (e.g. because the CLI argument
|
|
``--models`` was used when resuming a game), that model is loaded
|
|
instead of the one stored in the save file.
|
|
"""
|
|
model_id = override_model_id or agent_data["model_id"]
|
|
client = load_model_client(model_id, prompts_dir=prompts_dir)
|
|
|
|
# Keep the original or fallback token limit exactly as before.
|
|
client.max_tokens = agent_data.get("max_tokens", 16000)
|
|
|
|
agent = DiplomacyAgent(
|
|
power_name=agent_data["power_name"],
|
|
client=client,
|
|
initial_goals=agent_data.get("goals", []),
|
|
initial_relationships=agent_data.get("relationships", None),
|
|
prompts_dir=prompts_dir,
|
|
)
|
|
|
|
# Restore diary state
|
|
agent.full_private_diary = agent_data.get("full_private_diary", [])
|
|
agent.private_diary = agent_data.get("private_diary", [])
|
|
|
|
return agent
|
|
|
|
|
|
# --- State Management ---
|
|
|
|
# game_logic.py
|
|
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
|
|
|
|
|
def _next_phase_name(short: str) -> str:
|
|
"""
|
|
Return the Diplomacy phase string that chronologically follows *short*.
|
|
(E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M)
|
|
"""
|
|
season = short[0] # 'S' | 'W'
|
|
year = int(short[1:5])
|
|
typ = short[-1] # 'M' | 'R' | 'A'
|
|
|
|
idx = _PHASE_ORDER.index(typ)
|
|
if idx < 2: # still in the same season
|
|
return f"{season}{year}{_PHASE_ORDER[idx + 1]}"
|
|
|
|
# typ was 'A' → roll season
|
|
if season == "S": # summer → winter, same year
|
|
return f"W{year}M"
|
|
else: # winter→ spring, next year
|
|
return f"S{year + 1}M"
|
|
|
|
|
|
def save_game_state(
|
|
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
|
):
|
|
"""
|
|
Serialise the entire game to JSON, preserving per-phase custom metadata
|
|
(e.g. 'state_agents') that may have been written by earlier save passes.
|
|
"""
|
|
logger.info(f"Saving game state to {output_path}…")
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# 1. If the file already exists, cache the per-phase custom blocks. #
|
|
# ------------------------------------------------------------------ #
|
|
previous_phase_extras: Dict[str, Dict[str, Any]] = {}
|
|
if os.path.isfile(output_path):
|
|
try:
|
|
with open(output_path, "r", encoding="utf-8") as fh:
|
|
previous_save = json.load(fh)
|
|
for phase in previous_save.get("phases", []):
|
|
# Keep a copy of *all* non-standard keys so that future
|
|
# additions survive automatically.
|
|
extras = {
|
|
k: v
|
|
for k, v in phase.items()
|
|
if k
|
|
not in {
|
|
"name",
|
|
"orders",
|
|
"results",
|
|
"messages",
|
|
"state",
|
|
"config",
|
|
}
|
|
}
|
|
if extras:
|
|
previous_phase_extras[phase["name"]] = extras
|
|
except Exception as exc:
|
|
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
|
|
|
# -------------------------------------------------------------- #
|
|
# 2. Build the fresh base structure from the diplomacy library. #
|
|
# -------------------------------------------------------------- #
|
|
saved_game = to_saved_game_format(game)
|
|
|
|
# -------------------------------------------------------------- #
|
|
# 3. Walk every phase and merge the metadata back in. #
|
|
# -------------------------------------------------------------- #
|
|
# Capture the *current* snapshot of every live agent exactly once.
|
|
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
|
|
|
for phase_block in saved_game.get("phases", []):
|
|
if int(phase_block["name"][1:5]) > run_config.max_year:
|
|
break
|
|
|
|
phase_name = phase_block["name"]
|
|
|
|
# 3a. Re-attach anything we cached from a previous save.
|
|
if phase_name in previous_phase_extras:
|
|
phase_block.update(previous_phase_extras[phase_name])
|
|
|
|
# 3b. For *this* phase we also inject the fresh agent snapshot
|
|
# and the plans written during the turn.
|
|
if phase_name == completed_phase_name:
|
|
# ---- make run_config serialisable ---------------------------------
|
|
cfg = vars(run_config).copy()
|
|
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
|
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
|
if isinstance(cfg.get("prompts_dir"), Path):
|
|
cfg["prompts_dir"] = str(cfg["prompts_dir"])
|
|
# -------------------------------------------------------------------
|
|
phase_block["config"] = cfg
|
|
phase_block["state_agents"] = current_state_agents
|
|
|
|
# -------------------------------------------------------------- #
|
|
# 4. Attach top-level metadata and write atomically. #
|
|
# -------------------------------------------------------------- #
|
|
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
|
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
|
|
|
# Filter out phases > max_year
|
|
# saved_game["phases"] = [
|
|
# ph for ph in saved_game["phases"]
|
|
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
|
|
# ]
|
|
atomic_write_json(saved_game, output_path)
|
|
|
|
logger.info("Game state saved successfully.")
|
|
|
|
|
|
def load_game_state(
|
|
run_dir: str,
|
|
game_file_name: str,
|
|
run_config: Namespace,
|
|
resume_from_phase: Optional[str] = None,
|
|
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
|
|
"""Loads and reconstructs the game state from a saved game file."""
|
|
game_file_path = os.path.join(run_dir, game_file_name)
|
|
if not os.path.exists(game_file_path):
|
|
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
|
|
|
logger.info(f"Loading game state from: {game_file_path}")
|
|
with open(game_file_path, "r") as f:
|
|
saved_game_data = json.load(f)
|
|
|
|
# If resuming, find the specified phase and truncate the data after it
|
|
if resume_from_phase:
|
|
logger.info(f"Resuming from phase '{resume_from_phase}'. Truncating subsequent data.")
|
|
try:
|
|
# Find the index of the phase *before* the one we want to resume from.
|
|
# We will start the simulation *at* the resume_from_phase.
|
|
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
|
# Truncate the list to exclude everything after the resume phase
|
|
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
|
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
|
|
|
# Wipe any data that must be regenerated.
|
|
for key in ("orders", "results", "messages"):
|
|
saved_game_data["phases"][-1].pop(key, None)
|
|
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
|
except StopIteration:
|
|
# If the phase is not found, maybe it's the first phase (S1901M)
|
|
if resume_from_phase == "S1901M":
|
|
saved_game_data["phases"] = []
|
|
logger.info("Resuming from S1901M. Starting with a clean history.")
|
|
else:
|
|
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
|
|
|
# Reconstruct the Game object
|
|
last_phase = saved_game_data["phases"][-1]
|
|
|
|
# Wipe the data that must be regenerated **but preserve the keys**
|
|
last_phase["orders"] = {} # was dict
|
|
last_phase["results"] = {} # was dict
|
|
last_phase["messages"] = []
|
|
|
|
game = from_saved_game_format(saved_game_data)
|
|
|
|
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
|
|
|
# Reconstruct agents and game history from the *last* valid phase in the data
|
|
if not saved_game_data["phases"]:
|
|
# This happens if we are resuming from the very beginning (S1901M)
|
|
logger.info("No previous phases found. Initializing fresh agents and history.")
|
|
agents = {} # Will be created by the main loop
|
|
game_history = GameHistory()
|
|
else:
|
|
# We save the game state up to & including the current (uncompleted) phase.
|
|
# So we need to grab the agent state from the previous (completed) phase.
|
|
if len(saved_game_data["phases"]) <= 1:
|
|
last_phase_data = {}
|
|
else:
|
|
last_phase_data = saved_game_data["phases"][-2]
|
|
|
|
# -------------------- Rebuild agents -------------------- #
|
|
agents = {}
|
|
|
|
# Build a power→model map from the CLI argument --models, if present.
|
|
power_model_map: Dict[str, str] = {}
|
|
if run_config and getattr(run_config, "models", None):
|
|
provided = [m.strip() for m in run_config.models.split(",")]
|
|
powers_order = sorted(list(ALL_POWERS))
|
|
if len(provided) == len(powers_order):
|
|
power_model_map = dict(zip(powers_order, provided))
|
|
elif len(provided) == 1:
|
|
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
|
else:
|
|
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
|
|
|
if "state_agents" in last_phase_data:
|
|
logger.info("Rebuilding agents from saved state...")
|
|
|
|
for power_name, agent_data in last_phase_data["state_agents"].items():
|
|
override_id = power_model_map.get(power_name)
|
|
prompts_dir_from_config = (
|
|
run_config.prompts_dir_map.get(power_name)
|
|
if getattr(run_config, "prompts_dir_map", None)
|
|
else run_config.prompts_dir # fallback to old single path
|
|
)
|
|
agents[power_name] = deserialize_agent(
|
|
agent_data,
|
|
prompts_dir=prompts_dir_from_config,
|
|
override_model_id=override_id,
|
|
)
|
|
logger.info(f"Rebuilt {len(agents)} agents.")
|
|
else:
|
|
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
|
|
|
# Rebuild GameHistory
|
|
game_history = GameHistory()
|
|
logger.info("Rebuilding game history...")
|
|
for phase_data in saved_game_data["phases"][:-1]:
|
|
phase_name = phase_data["name"]
|
|
game_history.add_phase(phase_name)
|
|
# Add messages
|
|
for msg in phase_data.get("messages", []):
|
|
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
|
# Add plans
|
|
if "state_history_plans" in phase_data:
|
|
for p_name, plan in phase_data["state_history_plans"].items():
|
|
game_history.add_plan(phase_name, p_name, plan)
|
|
logger.info("Game history rebuilt.")
|
|
|
|
return game, agents, game_history, run_config
|
|
|
|
|
|
# ai_diplomacy/game_logic.py
|
|
async def initialize_new_game(
|
|
args: Namespace,
|
|
game: Game,
|
|
game_history: GameHistory,
|
|
llm_log_file_path: str,
|
|
) -> Dict[str, DiplomacyAgent]:
|
|
"""Initializes agents for a new game (supports per-power prompt directories)."""
|
|
|
|
powers_order = sorted(list(ALL_POWERS))
|
|
|
|
# Parse token limits
|
|
default_max_tokens = args.max_tokens
|
|
model_max_tokens = {p: default_max_tokens for p in powers_order}
|
|
|
|
if args.max_tokens_per_model:
|
|
per_model_values = [s.strip() for s in args.max_tokens_per_model.split(",")]
|
|
if len(per_model_values) == 7:
|
|
for power, token_val_str in zip(powers_order, per_model_values):
|
|
model_max_tokens[power] = int(token_val_str)
|
|
else:
|
|
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
|
|
|
# Handle power-model mapping
|
|
if args.models:
|
|
provided_models = [name.strip() for name in args.models.split(",")]
|
|
if len(provided_models) == len(powers_order):
|
|
game.power_model_map = dict(zip(powers_order, provided_models))
|
|
elif len(provided_models) == 1:
|
|
game.power_model_map = dict(zip(powers_order, provided_models * 7))
|
|
else:
|
|
logger.error(
|
|
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}."
|
|
)
|
|
raise Exception(
|
|
"Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited."
|
|
)
|
|
else:
|
|
game.power_model_map = assign_models_to_powers()
|
|
|
|
agents: Dict[str, DiplomacyAgent] = {}
|
|
initialization_tasks = []
|
|
logger.info("Initializing Diplomacy Agents for each power...")
|
|
|
|
for power_name, model_id in game.power_model_map.items():
|
|
if not game.powers[power_name].is_eliminated():
|
|
# Determine the prompts directory for this power
|
|
if hasattr(args, "prompts_dir_map") and args.prompts_dir_map:
|
|
prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir)
|
|
else:
|
|
prompts_dir_for_power = args.prompts_dir
|
|
|
|
try:
|
|
client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
|
|
client.max_tokens = model_max_tokens[power_name]
|
|
agent = DiplomacyAgent(
|
|
power_name=power_name,
|
|
client=client,
|
|
prompts_dir=prompts_dir_for_power,
|
|
)
|
|
agents[power_name] = agent
|
|
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
|
initialization_tasks.append(
|
|
initialize_agent_state_ext(
|
|
agent,
|
|
game,
|
|
game_history,
|
|
llm_log_file_path,
|
|
prompts_dir=prompts_dir_for_power,
|
|
)
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
|
|
exc_info=True,
|
|
)
|
|
|
|
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
|
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
|
|
|
initialized_powers = list(agents.keys())
|
|
for i, result in enumerate(initialization_results):
|
|
if i < len(initialized_powers):
|
|
power_name = initialized_powers[i]
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
|
else:
|
|
logger.info(f"Successfully initialized agent state for {power_name}.")
|
|
|
|
return agents
|
|
|