fix missing data in gamehistory object on resume

This commit is contained in:
sam-paech 2025-07-16 16:51:48 +10:00
parent 4a19548cac
commit 7c59153ab4

View file

@ -80,25 +80,26 @@ def _phase_year(phase_name: str) -> Optional[int]:
def save_game_state(
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
game: "Game",
agents: Dict[str, "DiplomacyAgent"],
game_history: "GameHistory",
output_path: str,
run_config,
completed_phase_name: str,
):
"""
Serialise the entire game to JSON, preserving per-phase custom metadata
(e.g. 'state_agents') that may have been written by earlier save passes.
Serialise the entire game to JSON, preserving per-phase custom metadata and
adding `state_phase_summaries` for every completed phase.
"""
logger.info(f"Saving game state to {output_path}")
# ------------------------------------------------------------------ #
# 1. If the file already exists, cache the per-phase custom blocks. #
# ------------------------------------------------------------------ #
# 1. If a previous save exists, cache its extra per-phase keys -------------
previous_phase_extras: Dict[str, Dict[str, Any]] = {}
if os.path.isfile(output_path):
try:
with open(output_path, "r", encoding="utf-8") as fh:
previous_save = json.load(fh)
for phase in previous_save.get("phases", []):
# Keep a copy of *all* non-standard keys so that future
# additions survive automatically.
extras = {
k: v
for k, v in phase.items()
@ -117,67 +118,64 @@ def save_game_state(
except Exception as exc:
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
# -------------------------------------------------------------- #
# 2. Build the fresh base structure from the diplomacy library. #
# -------------------------------------------------------------- #
# 2. Base structure from diplomacy-python ---------------------------------
saved_game = to_saved_game_format(game)
# -------------------------------------------------------------- #
# 3. Walk every phase and merge the metadata back in. #
# -------------------------------------------------------------- #
# Capture the *current* snapshot of every live agent exactly once.
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
# 3. Re-insert extras, order_results, phase_summaries, state_agents --------
current_state_agents = {
p_name: serialize_agent(p_agent)
for p_name, p_agent in agents.items()
if not game.powers[p_name].is_eliminated()
}
for phase_block in saved_game.get("phases", []):
year_val = _phase_year(phase_block["name"])
if year_val is not None and year_val > run_config.max_year:
break
phase_name = phase_block["name"]
phase_name = phase_block["name"]
# 3a. Re-attach anything we cached from a previous save.
# 3a. Merge cached extras
if phase_name in previous_phase_extras:
phase_block.update(previous_phase_extras[phase_name])
# 3b. For *this* phase we also inject the fresh agent snapshot
# and the plans written during the turn.
# 3b. Inject data only for the newly completed phase
if phase_name == completed_phase_name:
# ---- make run_config serialisable ---------------------------------
# Config made JSON-safe
cfg = vars(run_config).copy()
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
if isinstance(cfg.get("prompts_dir"), Path):
if isinstance(cfg.get("prompts_dir"), os.PathLike):
cfg["prompts_dir"] = str(cfg["prompts_dir"])
# -------------------------------------------------------------------
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(v) for p, v in cfg["prompts_dir_map"].items()}
phase_block["config"] = cfg
phase_block["state_agents"] = current_state_agents
phase_block["order_results"] = game_history.get_orders_history_for_phase(
game, completed_phase_name
)
phase_block["order_results"] = game_history.get_orders_history_for_phase(game, completed_phase_name)
# -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- #
# NEW: save per-power phase summaries
hist = game_history._get_phase(phase_name)
if hist and hist.phase_summaries:
phase_block["state_phase_summaries"] = hist.phase_summaries
# 4. Top-level metadata ----------------------------------------------------
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
saved_game["final_agent_states"] = {
p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()
}
# Filter out phases > max_year
# saved_game["phases"] = [
# ph for ph in saved_game["phases"]
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
# ]
atomic_write_json(saved_game, output_path)
logger.info("Game state saved successfully.")
def load_game_state(
run_dir: str,
game_file_name: str,
run_config: Namespace,
run_config,
resume_from_phase: Optional[str] = None,
) -> Tuple[Game, Dict[str, DiplomacyAgent], GameHistory, Optional[Namespace]]:
"""Loads and reconstructs the game state from a saved game file."""
) -> Tuple["Game", Dict[str, "DiplomacyAgent"], "GameHistory", Optional[Any]]:
"""
Load and fully re-hydrate the game, agents and GameHistory including
`orders_by_power`, `results_by_power`, `submitted_orders_by_power`,
and per-power `phase_summaries`.
"""
from collections import defaultdict # local to avoid new global import
game_file_path = os.path.join(run_dir, game_file_name)
if not os.path.exists(game_file_path):
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
@ -186,103 +184,100 @@ def load_game_state(
with open(game_file_path, "r") as f:
saved_game_data = json.load(f)
# If resuming, find the specified phase and truncate the data after it
# --- Trim history if --resume_from_phase was requested --------------------
if resume_from_phase:
logger.info(f"Resuming from phase '{resume_from_phase}'. Truncating subsequent data.")
try:
# Find the index of the phase *before* the one we want to resume from.
# We will start the simulation *at* the resume_from_phase.
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
# Truncate the list to exclude everything after the resume phase
# Note: the state saved for a given phase represents the state at the beginning of that phase.
resume_idx = next(i for i, ph in enumerate(saved_game_data["phases"]) if ph["name"] == resume_from_phase)
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
# Wipe any data that must be regenerated.
for key in ("orders", "results", "messages"):
saved_game_data["phases"][-1].pop(key, None)
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
for k in ("orders", "results", "messages"):
saved_game_data["phases"][-1].pop(k, None)
logger.info("Game history truncated for resume.")
except StopIteration:
# If the phase is not found, maybe it's the first phase (S1901M)
if resume_from_phase == "S1901M":
saved_game_data["phases"] = []
logger.info("Resuming from S1901M. Starting with a clean history.")
logger.info("Resuming from start clean history.")
else:
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
# Reconstruct the Game object
last_phase = saved_game_data["phases"][-1]
# Wipe the data that must be regenerated **but preserve the keys**
last_phase["orders"] = {}
last_phase["results"] = {}
last_phase["messages"] = []
# --- Reconstruct Game object ---------------------------------------------
if saved_game_data.get("phases"):
saved_game_data["phases"][-1].update({"orders": {}, "results": {}, "messages": []})
game = from_saved_game_format(saved_game_data)
game.phase_summaries = saved_game_data.get("phase_summaries", {})
# Reconstruct agents and game history from the *last* valid phase in the data
if not saved_game_data["phases"]:
# This happens if we are resuming from the very beginning (S1901M)
logger.info("No previous phases found. Initializing fresh agents and history.")
agents = {} # Will be created by the main loop
game_history = GameHistory()
else:
# We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase.
if len(saved_game_data["phases"]) <= 1:
last_phase_data = {}
# --- Rebuild agents -------------------------------------------------------
agents: Dict[str, "DiplomacyAgent"] = {}
power_model_map: Dict[str, str] = {}
if run_config and getattr(run_config, "models", None):
provided = [m.strip() for m in run_config.models.split(",")]
powers_order = sorted(list(ALL_POWERS))
if len(provided) == len(powers_order):
power_model_map = dict(zip(powers_order, provided))
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else:
last_phase_data = saved_game_data["phases"][-2]
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
# -------------------- Rebuild agents -------------------- #
agents = {}
if saved_game_data.get("phases"):
last_phase_data = saved_game_data["phases"][-2] if len(saved_game_data["phases"]) > 1 else {}
if "state_agents" not in last_phase_data:
raise ValueError("Cannot resume: 'state_agents' key missing in last completed phase.")
# Build a power→model map from the CLI argument --models, if present.
power_model_map: Dict[str, str] = {}
if run_config and getattr(run_config, "models", None):
provided = [m.strip() for m in run_config.models.split(",")]
powers_order = sorted(list(ALL_POWERS))
if len(provided) == len(powers_order):
power_model_map = dict(zip(powers_order, provided))
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
prompts_dir_from_config = (
run_config.prompts_dir_map.get(power_name)
if getattr(run_config, "prompts_dir_map", None)
else run_config.prompts_dir
)
agents[power_name] = deserialize_agent(
agent_data,
prompts_dir=prompts_dir_from_config,
override_model_id=override_id,
)
# --- Rebuild GameHistory --------------------------------------------------
game_history = GameHistory()
for phase_data in saved_game_data["phases"][:-1]:
phase_name = phase_data["name"]
game_history.add_phase(phase_name)
ph_obj = game_history._get_phase(phase_name)
# Messages
for msg in phase_data.get("messages", []):
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Plans
for p_name, plan in phase_data.get("state_history_plans", {}).items():
game_history.add_plan(phase_name, p_name, plan)
# --- NEW restorations --------------------------------------------------
# Accepted orders
ph_obj.orders_by_power = defaultdict(list, phase_data.get("orders", {}))
# Results (wrap scalar -> list[list[str]])
ph_obj.results_by_power = defaultdict(list)
for pwr, res_list in phase_data.get("results", {}).items():
if res_list and isinstance(res_list[0], list):
ph_obj.results_by_power[pwr] = res_list
else:
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
ph_obj.results_by_power[pwr] = [[r] for r in res_list]
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
prompts_dir_from_config = (
run_config.prompts_dir_map.get(power_name)
if getattr(run_config, "prompts_dir_map", None)
else run_config.prompts_dir # fallback to old single path
)
agents[power_name] = deserialize_agent(
agent_data,
prompts_dir=prompts_dir_from_config,
override_model_id=override_id,
)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
# Phase summaries
ph_obj.phase_summaries = phase_data.get("state_phase_summaries", {})
# Rebuild GameHistory
game_history = GameHistory()
logger.info("Rebuilding game history...")
for phase_data in saved_game_data["phases"][:-1]:
phase_name = phase_data["name"]
game_history.add_phase(phase_name)
# Add messages
for msg in phase_data.get("messages", []):
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Add plans
if "state_history_plans" in phase_data:
for p_name, plan in phase_data["state_history_plans"].items():
game_history.add_plan(phase_name, p_name, plan)
logger.info("Game history rebuilt.")
# Submitted orders reconstructed from order_results
submitted = defaultdict(list)
for pwr, type_map in phase_data.get("order_results", {}).items():
for lst in type_map.values():
for entry in lst:
if isinstance(entry, dict):
order_str = entry.get("order")
else:
order_str = entry
if order_str:
submitted[pwr].append(order_str)
ph_obj.submitted_orders_by_power = submitted
return game, agents, game_history, run_config