mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
fix missing data in gamehistory object on resume
This commit is contained in:
parent
4a19548cac
commit
7c59153ab4
1 changed files with 119 additions and 124 deletions
|
|
@ -80,25 +80,26 @@ def _phase_year(phase_name: str) -> Optional[int]:
|
|||
|
||||
|
||||
def save_game_state(
|
||||
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
||||
game: "Game",
|
||||
agents: Dict[str, "DiplomacyAgent"],
|
||||
game_history: "GameHistory",
|
||||
output_path: str,
|
||||
run_config,
|
||||
completed_phase_name: str,
|
||||
):
|
||||
"""
|
||||
Serialise the entire game to JSON, preserving per-phase custom metadata
|
||||
(e.g. 'state_agents') that may have been written by earlier save passes.
|
||||
Serialise the entire game to JSON, preserving per-phase custom metadata and
|
||||
adding `state_phase_summaries` for every completed phase.
|
||||
"""
|
||||
logger.info(f"Saving game state to {output_path}…")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 1. If the file already exists, cache the per-phase custom blocks. #
|
||||
# ------------------------------------------------------------------ #
|
||||
# 1. If a previous save exists, cache its extra per-phase keys -------------
|
||||
previous_phase_extras: Dict[str, Dict[str, Any]] = {}
|
||||
if os.path.isfile(output_path):
|
||||
try:
|
||||
with open(output_path, "r", encoding="utf-8") as fh:
|
||||
previous_save = json.load(fh)
|
||||
for phase in previous_save.get("phases", []):
|
||||
# Keep a copy of *all* non-standard keys so that future
|
||||
# additions survive automatically.
|
||||
extras = {
|
||||
k: v
|
||||
for k, v in phase.items()
|
||||
|
|
@ -117,67 +118,64 @@ def save_game_state(
|
|||
except Exception as exc:
|
||||
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 2. Build the fresh base structure from the diplomacy library. #
|
||||
# -------------------------------------------------------------- #
|
||||
# 2. Base structure from diplomacy-python ---------------------------------
|
||||
saved_game = to_saved_game_format(game)
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 3. Walk every phase and merge the metadata back in. #
|
||||
# -------------------------------------------------------------- #
|
||||
# Capture the *current* snapshot of every live agent exactly once.
|
||||
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
||||
# 3. Re-insert extras, order_results, phase_summaries, state_agents --------
|
||||
current_state_agents = {
|
||||
p_name: serialize_agent(p_agent)
|
||||
for p_name, p_agent in agents.items()
|
||||
if not game.powers[p_name].is_eliminated()
|
||||
}
|
||||
|
||||
for phase_block in saved_game.get("phases", []):
|
||||
year_val = _phase_year(phase_block["name"])
|
||||
if year_val is not None and year_val > run_config.max_year:
|
||||
break
|
||||
phase_name = phase_block["name"]
|
||||
|
||||
phase_name = phase_block["name"]
|
||||
|
||||
# 3a. Re-attach anything we cached from a previous save.
|
||||
# 3a. Merge cached extras
|
||||
if phase_name in previous_phase_extras:
|
||||
phase_block.update(previous_phase_extras[phase_name])
|
||||
|
||||
# 3b. For *this* phase we also inject the fresh agent snapshot
|
||||
# and the plans written during the turn.
|
||||
# 3b. Inject data only for the newly completed phase
|
||||
if phase_name == completed_phase_name:
|
||||
# ---- make run_config serialisable ---------------------------------
|
||||
# Config made JSON-safe
|
||||
cfg = vars(run_config).copy()
|
||||
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
||||
if isinstance(cfg.get("prompts_dir"), Path):
|
||||
if isinstance(cfg.get("prompts_dir"), os.PathLike):
|
||||
cfg["prompts_dir"] = str(cfg["prompts_dir"])
|
||||
# -------------------------------------------------------------------
|
||||
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||
cfg["prompts_dir_map"] = {p: str(v) for p, v in cfg["prompts_dir_map"].items()}
|
||||
|
||||
phase_block["config"] = cfg
|
||||
phase_block["state_agents"] = current_state_agents
|
||||
phase_block["order_results"] = game_history.get_orders_history_for_phase(
|
||||
game, completed_phase_name
|
||||
)
|
||||
phase_block["order_results"] = game_history.get_orders_history_for_phase(game, completed_phase_name)
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 4. Attach top-level metadata and write atomically. #
|
||||
# -------------------------------------------------------------- #
|
||||
# NEW: save per-power phase summaries
|
||||
hist = game_history._get_phase(phase_name)
|
||||
if hist and hist.phase_summaries:
|
||||
phase_block["state_phase_summaries"] = hist.phase_summaries
|
||||
|
||||
# 4. Top-level metadata ----------------------------------------------------
|
||||
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
||||
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
||||
saved_game["final_agent_states"] = {
|
||||
p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()
|
||||
}
|
||||
|
||||
# Filter out phases > max_year
|
||||
# saved_game["phases"] = [
|
||||
# ph for ph in saved_game["phases"]
|
||||
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
|
||||
# ]
|
||||
atomic_write_json(saved_game, output_path)
|
||||
|
||||
logger.info("Game state saved successfully.")
|
||||
|
||||
|
||||
def load_game_state(
|
||||
run_dir: str,
|
||||
game_file_name: str,
|
||||
run_config: Namespace,
|
||||
run_config,
|
||||
resume_from_phase: Optional[str] = None,
|
||||
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
|
||||
"""Loads and reconstructs the game state from a saved game file."""
|
||||
) -> Tuple["Game", Dict[str, "DiplomacyAgent"], "GameHistory", Optional[Any]]:
|
||||
"""
|
||||
Load and fully re-hydrate the game, agents and GameHistory – including
|
||||
`orders_by_power`, `results_by_power`, `submitted_orders_by_power`,
|
||||
and per-power `phase_summaries`.
|
||||
"""
|
||||
from collections import defaultdict # local to avoid new global import
|
||||
|
||||
game_file_path = os.path.join(run_dir, game_file_name)
|
||||
if not os.path.exists(game_file_path):
|
||||
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
||||
|
|
@ -186,103 +184,100 @@ def load_game_state(
|
|||
with open(game_file_path, "r") as f:
|
||||
saved_game_data = json.load(f)
|
||||
|
||||
# If resuming, find the specified phase and truncate the data after it
|
||||
# --- Trim history if --resume_from_phase was requested --------------------
|
||||
if resume_from_phase:
|
||||
logger.info(f"Resuming from phase '{resume_from_phase}'. Truncating subsequent data.")
|
||||
try:
|
||||
# Find the index of the phase *before* the one we want to resume from.
|
||||
# We will start the simulation *at* the resume_from_phase.
|
||||
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
||||
# Truncate the list to exclude everything after the resume phase
|
||||
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
||||
resume_idx = next(i for i, ph in enumerate(saved_game_data["phases"]) if ph["name"] == resume_from_phase)
|
||||
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
||||
|
||||
# Wipe any data that must be regenerated.
|
||||
for key in ("orders", "results", "messages"):
|
||||
saved_game_data["phases"][-1].pop(key, None)
|
||||
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
||||
for k in ("orders", "results", "messages"):
|
||||
saved_game_data["phases"][-1].pop(k, None)
|
||||
logger.info("Game history truncated for resume.")
|
||||
except StopIteration:
|
||||
# If the phase is not found, maybe it's the first phase (S1901M)
|
||||
if resume_from_phase == "S1901M":
|
||||
saved_game_data["phases"] = []
|
||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||
logger.info("Resuming from start – clean history.")
|
||||
else:
|
||||
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
||||
|
||||
# Reconstruct the Game object
|
||||
last_phase = saved_game_data["phases"][-1]
|
||||
|
||||
# Wipe the data that must be regenerated **but preserve the keys**
|
||||
last_phase["orders"] = {}
|
||||
last_phase["results"] = {}
|
||||
last_phase["messages"] = []
|
||||
|
||||
# --- Reconstruct Game object ---------------------------------------------
|
||||
if saved_game_data.get("phases"):
|
||||
saved_game_data["phases"][-1].update({"orders": {}, "results": {}, "messages": []})
|
||||
game = from_saved_game_format(saved_game_data)
|
||||
|
||||
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
||||
|
||||
# Reconstruct agents and game history from the *last* valid phase in the data
|
||||
if not saved_game_data["phases"]:
|
||||
# This happens if we are resuming from the very beginning (S1901M)
|
||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||
agents = {} # Will be created by the main loop
|
||||
game_history = GameHistory()
|
||||
else:
|
||||
# We save the game state up to & including the current (uncompleted) phase.
|
||||
# So we need to grab the agent state from the previous (completed) phase.
|
||||
if len(saved_game_data["phases"]) <= 1:
|
||||
last_phase_data = {}
|
||||
# --- Rebuild agents -------------------------------------------------------
|
||||
agents: Dict[str, "DiplomacyAgent"] = {}
|
||||
power_model_map: Dict[str, str] = {}
|
||||
if run_config and getattr(run_config, "models", None):
|
||||
provided = [m.strip() for m in run_config.models.split(",")]
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
if len(provided) == len(powers_order):
|
||||
power_model_map = dict(zip(powers_order, provided))
|
||||
elif len(provided) == 1:
|
||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||
else:
|
||||
last_phase_data = saved_game_data["phases"][-2]
|
||||
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||
|
||||
# -------------------- Rebuild agents -------------------- #
|
||||
agents = {}
|
||||
if saved_game_data.get("phases"):
|
||||
last_phase_data = saved_game_data["phases"][-2] if len(saved_game_data["phases"]) > 1 else {}
|
||||
if "state_agents" not in last_phase_data:
|
||||
raise ValueError("Cannot resume: 'state_agents' key missing in last completed phase.")
|
||||
|
||||
# Build a power→model map from the CLI argument --models, if present.
|
||||
power_model_map: Dict[str, str] = {}
|
||||
if run_config and getattr(run_config, "models", None):
|
||||
provided = [m.strip() for m in run_config.models.split(",")]
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
if len(provided) == len(powers_order):
|
||||
power_model_map = dict(zip(powers_order, provided))
|
||||
elif len(provided) == 1:
|
||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir_map.get(power_name)
|
||||
if getattr(run_config, "prompts_dir_map", None)
|
||||
else run_config.prompts_dir
|
||||
)
|
||||
agents[power_name] = deserialize_agent(
|
||||
agent_data,
|
||||
prompts_dir=prompts_dir_from_config,
|
||||
override_model_id=override_id,
|
||||
)
|
||||
|
||||
# --- Rebuild GameHistory --------------------------------------------------
|
||||
game_history = GameHistory()
|
||||
for phase_data in saved_game_data["phases"][:-1]:
|
||||
phase_name = phase_data["name"]
|
||||
game_history.add_phase(phase_name)
|
||||
ph_obj = game_history._get_phase(phase_name)
|
||||
|
||||
# Messages
|
||||
for msg in phase_data.get("messages", []):
|
||||
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||
|
||||
# Plans
|
||||
for p_name, plan in phase_data.get("state_history_plans", {}).items():
|
||||
game_history.add_plan(phase_name, p_name, plan)
|
||||
|
||||
# --- NEW restorations --------------------------------------------------
|
||||
# Accepted orders
|
||||
ph_obj.orders_by_power = defaultdict(list, phase_data.get("orders", {}))
|
||||
|
||||
# Results (wrap scalar -> list[list[str]])
|
||||
ph_obj.results_by_power = defaultdict(list)
|
||||
for pwr, res_list in phase_data.get("results", {}).items():
|
||||
if res_list and isinstance(res_list[0], list):
|
||||
ph_obj.results_by_power[pwr] = res_list
|
||||
else:
|
||||
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||
ph_obj.results_by_power[pwr] = [[r] for r in res_list]
|
||||
|
||||
if "state_agents" in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir_map.get(power_name)
|
||||
if getattr(run_config, "prompts_dir_map", None)
|
||||
else run_config.prompts_dir # fallback to old single path
|
||||
)
|
||||
agents[power_name] = deserialize_agent(
|
||||
agent_data,
|
||||
prompts_dir=prompts_dir_from_config,
|
||||
override_model_id=override_id,
|
||||
)
|
||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||
else:
|
||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||
# Phase summaries
|
||||
ph_obj.phase_summaries = phase_data.get("state_phase_summaries", {})
|
||||
|
||||
# Rebuild GameHistory
|
||||
game_history = GameHistory()
|
||||
logger.info("Rebuilding game history...")
|
||||
for phase_data in saved_game_data["phases"][:-1]:
|
||||
phase_name = phase_data["name"]
|
||||
game_history.add_phase(phase_name)
|
||||
# Add messages
|
||||
for msg in phase_data.get("messages", []):
|
||||
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||
# Add plans
|
||||
if "state_history_plans" in phase_data:
|
||||
for p_name, plan in phase_data["state_history_plans"].items():
|
||||
game_history.add_plan(phase_name, p_name, plan)
|
||||
logger.info("Game history rebuilt.")
|
||||
# Submitted orders reconstructed from order_results
|
||||
submitted = defaultdict(list)
|
||||
for pwr, type_map in phase_data.get("order_results", {}).items():
|
||||
for lst in type_map.values():
|
||||
for entry in lst:
|
||||
if isinstance(entry, dict):
|
||||
order_str = entry.get("order")
|
||||
else:
|
||||
order_str = entry
|
||||
if order_str:
|
||||
submitted[pwr].append(order_str)
|
||||
ph_obj.submitted_orders_by_power = submitted
|
||||
|
||||
return game, agents, game_history, run_config
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue