AI_Diplomacy/ai_diplomacy/game_logic.py

375 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ai_diplomacy/game_logic.py
import logging
import os
import json
import asyncio
from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from pathlib import Path
import re
from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
from .agent import DiplomacyAgent, ALL_POWERS
from .clients import load_model_client
from .game_history import GameHistory
from .initialization import initialize_agent_state_ext
from .utils import atomic_write_json, assign_models_to_powers
# Module-level logger named after this module via ``__name__``; used by every
# function below for progress, warning and error reporting.
logger = logging.getLogger(__name__)
# --- Serialization / Deserialization ---
def serialize_agent(agent: DiplomacyAgent) -> dict:
    """
    Snapshot an agent into a plain, JSON-serializable dictionary.

    The model identity and token budget are taken from the agent's client;
    goals, relationships and both diary views are taken from the agent
    itself. The keys produced here are the ones ``deserialize_agent`` reads
    back. Values are referenced, not copied.
    """
    power = agent.power_name
    client = agent.client
    snapshot = dict(
        power_name=power,
        model_id=client.model_name,
        max_tokens=client.max_tokens,
    )
    # Remaining fields share their attribute name with their output key.
    for field in ("goals", "relationships", "full_private_diary", "private_diary"):
        snapshot[field] = getattr(agent, field)
    return snapshot
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
    """
    Rebuild a ``DiplomacyAgent`` from a dict produced by ``serialize_agent``.

    A truthy *override_model_id* (e.g. from ``--models`` on resume) replaces
    the ``model_id`` stored in *agent_data*. The client's ``max_tokens`` is
    restored from the save, defaulting to 16000 when absent. Goals,
    relationships and both diaries are restored when present.
    """
    chosen_model = override_model_id if override_model_id else agent_data["model_id"]
    model_client = load_model_client(chosen_model, prompts_dir=prompts_dir)
    model_client.max_tokens = agent_data.get("max_tokens", 16000)

    agent_kwargs = {
        "power_name": agent_data["power_name"],
        "client": model_client,
        "initial_goals": agent_data.get("goals", []),
        "initial_relationships": agent_data.get("relationships", None),
        "prompts_dir": prompts_dir,
    }
    restored = DiplomacyAgent(**agent_kwargs)

    # Diary fields are stored under the same names as the agent attributes.
    for diary_attr in ("full_private_diary", "private_diary"):
        setattr(restored, diary_attr, agent_data.get(diary_attr, []))
    return restored
# --- State Management ---
_PHASE_RE = re.compile(r"^[SW](\d{4})[MRA]$")
def _phase_year(phase_name: str) -> Optional[int]:
"""
Return the four-digit year encoded in standard phase strings
like 'S1901M'. For anything non-standard (e.g. 'COMPLETE')
return None so callers can decide how to handle it.
"""
m = _PHASE_RE.match(phase_name)
return int(m.group(1)) if m else None
def save_game_state(
    game: "Game",
    agents: Dict[str, "DiplomacyAgent"],
    game_history: "GameHistory",
    output_path: str,
    run_config,
    completed_phase_name: str,
):
    """
    Write the full game to *output_path* as JSON.

    Starts from ``to_saved_game_format(game)`` and then:
      * re-attaches every non-core key that an earlier save stored on each
        phase (read from the existing file at *output_path*, if any; a
        failure to read it is logged as a warning and otherwise ignored);
      * for the phase named *completed_phase_name* only, adds ``config``
        (a copy of ``vars(run_config)`` with path values turned into str),
        ``state_agents`` (serialized agents of non-eliminated powers),
        ``order_results`` and, when available, ``state_phase_summaries``;
      * adds top-level ``phase_summaries`` and ``final_agent_states``.
    The result is written with ``atomic_write_json``.
    """
    logger.info(f"Saving game state to {output_path}")

    # Keys produced by to_saved_game_format (plus our own "config", which is
    # regenerated); anything else on a phase is custom metadata to carry over.
    core_keys = frozenset({"name", "orders", "results", "messages", "state", "config"})

    # Step 1: harvest custom per-phase metadata from a previous save.
    extras_by_phase: Dict[str, Dict[str, Any]] = {}
    if os.path.isfile(output_path):
        try:
            with open(output_path, "r", encoding="utf-8") as handle:
                prior = json.load(handle)
            for old_phase in prior.get("phases", []):
                leftover = {key: val for key, val in old_phase.items() if key not in core_keys}
                if not leftover:
                    continue
                extras_by_phase[old_phase["name"]] = leftover
        except Exception as exc:
            logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)

    # Step 2: base structure from the diplomacy package.
    saved_game = to_saved_game_format(game)

    # Step 3: enrich phases.
    live_agent_states = {}
    for power, agent in agents.items():
        if game.powers[power].is_eliminated():
            continue
        live_agent_states[power] = serialize_agent(agent)

    for block in saved_game.get("phases", []):
        name = block["name"]
        block.update(extras_by_phase.get(name, {}))

        if name != completed_phase_name:
            continue

        # JSON-safe copy of the run configuration.
        config_snapshot = dict(vars(run_config))
        prompts_dir_value = config_snapshot.get("prompts_dir")
        if isinstance(prompts_dir_value, os.PathLike):
            config_snapshot["prompts_dir"] = str(prompts_dir_value)
        dir_map = config_snapshot.get("prompts_dir_map")
        if isinstance(dir_map, dict):
            config_snapshot["prompts_dir_map"] = {pw: str(path) for pw, path in dir_map.items()}

        block["config"] = config_snapshot
        block["state_agents"] = live_agent_states
        block["order_results"] = game_history.get_orders_history_for_phase(game, completed_phase_name)

        phase_record = game_history._get_phase(name)
        if phase_record and phase_record.phase_summaries:
            block["state_phase_summaries"] = phase_record.phase_summaries

    # Step 4: top-level metadata.
    saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
    final_states = {}
    for power, agent in agents.items():
        final_states[power] = {"relationships": agent.relationships, "goals": agent.goals}
    saved_game["final_agent_states"] = final_states

    atomic_write_json(saved_game, output_path)
    logger.info("Game state saved successfully.")
def load_game_state(
    run_dir: str,
    game_file_name: str,
    run_config,
    resume_from_phase: Optional[str] = None,
) -> Tuple["Game", Dict[str, "DiplomacyAgent"], "GameHistory", Optional[Any]]:
    """
    Load and fully re-hydrate the game, agents and GameHistory including
    `orders_by_power`, `results_by_power`, `submitted_orders_by_power`,
    and per-power `phase_summaries`.

    Args:
        run_dir: Directory containing the save file.
        game_file_name: Name of the save file inside *run_dir*.
        run_config: Run configuration (CLI namespace). ``models``,
            ``prompts_dir`` and ``prompts_dir_map`` are read when present.
        resume_from_phase: Optional phase name. Phases after it are dropped
            and its own orders/results/messages are cleared. If it is not in
            the save and equals 'S1901M', the phase history is emptied.

    Returns:
        ``(game, agents, game_history, run_config)``. *run_config* is
        returned unchanged.

    Raises:
        FileNotFoundError: If the save file does not exist.
        ValueError: If *resume_from_phase* is missing from the save (other
            than the 'S1901M' case), if ``--models`` has the wrong number of
            entries, or if the last completed phase has no ``state_agents``.
    """
    from collections import defaultdict  # local to avoid new global import

    game_file_path = os.path.join(run_dir, game_file_name)
    if not os.path.exists(game_file_path):
        raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
    logger.info(f"Loading game state from: {game_file_path}")
    # Decode explicitly as UTF-8 (as save_game_state does when it re-reads
    # this file); the platform default encoding, e.g. cp1252 on Windows, can
    # fail on non-ASCII text stored in the save.
    with open(game_file_path, "r", encoding="utf-8") as f:
        saved_game_data = json.load(f)

    # --- Trim history if --resume_from_phase was requested --------------------
    if resume_from_phase:
        resume_idx = next(
            (i for i, ph in enumerate(saved_game_data["phases"]) if ph["name"] == resume_from_phase),
            None,
        )
        if resume_idx is not None:
            saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
            for k in ("orders", "results", "messages"):
                saved_game_data["phases"][-1].pop(k, None)
            logger.info("Game history truncated for resume.")
        elif resume_from_phase == "S1901M":
            saved_game_data["phases"] = []
            logger.info("Resuming from start clean history.")
        else:
            raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")

    # --- Reconstruct Game object ---------------------------------------------
    # The final phase is treated as the current, not-yet-completed phase, so
    # it starts with no orders, results or messages.
    if saved_game_data.get("phases"):
        saved_game_data["phases"][-1].update({"orders": {}, "results": {}, "messages": []})
    game = from_saved_game_format(saved_game_data)
    game.phase_summaries = saved_game_data.get("phase_summaries", {})

    # --- Rebuild agents -------------------------------------------------------
    agents: Dict[str, "DiplomacyAgent"] = {}
    power_model_map: Dict[str, str] = {}
    if run_config and getattr(run_config, "models", None):
        provided = [m.strip() for m in run_config.models.split(",")]
        powers_order = sorted(ALL_POWERS)
        if len(provided) == len(powers_order):
            power_model_map = dict(zip(powers_order, provided))
        elif len(provided) == 1:
            power_model_map = dict.fromkeys(powers_order, provided[0])
        else:
            raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")

    # Powers absent from prompts_dir_map fall back to the global prompts_dir,
    # matching initialize_new_game (previously they silently got None).
    default_prompts_dir = getattr(run_config, "prompts_dir", None)
    prompts_dir_map = getattr(run_config, "prompts_dir_map", None)

    phases = saved_game_data.get("phases") or []
    if phases:
        # phases[-1] is the current phase; agent state was stored on the
        # previous (completed) phase by save_game_state.
        # NOTE(review): when only one phase remains (e.g. resuming at the
        # first saved phase) there is no previous phase, so this raises.
        last_phase_data = phases[-2] if len(phases) > 1 else {}
        if "state_agents" not in last_phase_data:
            raise ValueError("Cannot resume: 'state_agents' key missing in last completed phase.")
        for power_name, agent_data in last_phase_data["state_agents"].items():
            if prompts_dir_map:
                prompts_dir_for_power = prompts_dir_map.get(power_name, default_prompts_dir)
            else:
                prompts_dir_for_power = default_prompts_dir
            agents[power_name] = deserialize_agent(
                agent_data,
                prompts_dir=prompts_dir_for_power,
                override_model_id=power_model_map.get(power_name),
            )

    # --- Rebuild GameHistory --------------------------------------------------
    # Only completed phases (everything except the current one) are replayed.
    game_history = GameHistory()
    for phase_data in phases[:-1]:
        phase_name = phase_data["name"]
        game_history.add_phase(phase_name)
        ph_obj = game_history._get_phase(phase_name)

        # Messages
        for msg in phase_data.get("messages", []):
            game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])

        # Plans
        for p_name, plan in phase_data.get("state_history_plans", {}).items():
            game_history.add_plan(phase_name, p_name, plan)

        # Accepted orders
        ph_obj.orders_by_power = defaultdict(list, phase_data.get("orders", {}))

        # Results: normalise to list[list[str]] (wrap flat entries).
        ph_obj.results_by_power = defaultdict(list)
        for pwr, res_list in phase_data.get("results", {}).items():
            if res_list and isinstance(res_list[0], list):
                ph_obj.results_by_power[pwr] = res_list
            else:
                ph_obj.results_by_power[pwr] = [[r] for r in res_list]

        # Phase summaries
        ph_obj.phase_summaries = phase_data.get("state_phase_summaries", {})

        # Submitted orders reconstructed from order_results; entries may be
        # dicts carrying an "order" key or plain order strings.
        submitted = defaultdict(list)
        for pwr, type_map in phase_data.get("order_results", {}).items():
            for lst in type_map.values():
                for entry in lst:
                    order_str = entry.get("order") if isinstance(entry, dict) else entry
                    if order_str:
                        submitted[pwr].append(order_str)
        ph_obj.submitted_orders_by_power = submitted

    return game, agents, game_history, run_config
# ai_diplomacy/game_logic.py
async def initialize_new_game(
    args: Namespace,
    game: Game,
    game_history: GameHistory,
    llm_log_file_path: str,
) -> Dict[str, DiplomacyAgent]:
    """
    Initializes agents for a new game (supports per-power prompt directories).

    Assigns a model to every power, either from ``--models`` or from
    ``assign_models_to_powers()``, and stores the mapping on
    ``game.power_model_map``. It then builds a ``DiplomacyAgent`` for each
    non-eliminated power and runs their state initialization concurrently.

    Args:
        args: Parsed CLI arguments. Reads ``max_tokens``,
            ``max_tokens_per_model``, ``models``, ``prompts_dir`` and, if
            present, ``prompts_dir_map``.
        game: The game being started; ``power_model_map`` is set on it.
        game_history: Passed through to ``initialize_agent_state_ext``.
        llm_log_file_path: Passed through to ``initialize_agent_state_ext``.

    Returns:
        Mapping of power name -> agent for every power whose client and agent
        were created. Agents whose state initialization failed are still
        included; the failure is only logged.

    Raises:
        ValueError: If ``--models`` lists neither 1 model nor one per power.
    """
    powers_order = sorted(ALL_POWERS)
    num_powers = len(powers_order)

    # Parse token limits (one value per power, in sorted power order).
    default_max_tokens = args.max_tokens
    model_max_tokens = {p: default_max_tokens for p in powers_order}
    if args.max_tokens_per_model:
        per_model_values = [s.strip() for s in args.max_tokens_per_model.split(",")]
        if len(per_model_values) == num_powers:
            for power, token_val_str in zip(powers_order, per_model_values):
                model_max_tokens[power] = int(token_val_str)
        else:
            logger.warning(f"Expected {num_powers} values for --max_tokens_per_model, using default.")

    # Handle power-model mapping
    if args.models:
        provided_models = [name.strip() for name in args.models.split(",")]
        if len(provided_models) == num_powers:
            game.power_model_map = dict(zip(powers_order, provided_models))
        elif len(provided_models) == 1:
            game.power_model_map = dict.fromkeys(powers_order, provided_models[0])
        else:
            logger.error(f"Expected {num_powers} models for --models but got {len(provided_models)}.")
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError(
                f"Invalid number of models. Models list must be either exactly 1 or {num_powers} models, comma delimited."
            )
    else:
        game.power_model_map = assign_models_to_powers()

    agents: Dict[str, DiplomacyAgent] = {}
    # (power_name, coroutine) pairs, so each gather() result is attributed to
    # the power that produced it without relying on parallel list indices.
    pending_inits = []
    prompts_dir_map = getattr(args, "prompts_dir_map", None)
    logger.info("Initializing Diplomacy Agents for each power...")
    for power_name, model_id in game.power_model_map.items():
        if game.powers[power_name].is_eliminated():
            continue
        # Determine the prompts directory for this power
        if prompts_dir_map:
            prompts_dir_for_power = prompts_dir_map.get(power_name, args.prompts_dir)
        else:
            prompts_dir_for_power = args.prompts_dir
        try:
            client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
            client.max_tokens = model_max_tokens[power_name]
            agent = DiplomacyAgent(
                power_name=power_name,
                client=client,
                prompts_dir=prompts_dir_for_power,
            )
            agents[power_name] = agent
            logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
            pending_inits.append(
                (
                    power_name,
                    initialize_agent_state_ext(
                        agent,
                        game,
                        game_history,
                        llm_log_file_path,
                        prompts_dir=prompts_dir_for_power,
                    ),
                )
            )
        except Exception as e:
            logger.error(
                f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
                exc_info=True,
            )

    logger.info(f"Running {len(pending_inits)} agent initializations concurrently...")
    initialization_results = await asyncio.gather(
        *(coro for _, coro in pending_inits), return_exceptions=True
    )
    for (power_name, _), result in zip(pending_inits, initialization_results):
        # BaseException also covers asyncio.CancelledError, which gather()
        # returns (rather than raises) under return_exceptions=True.
        if isinstance(result, BaseException):
            logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
        else:
            logger.info(f"Successfully initialized agent state for {power_name}.")
    return agents