fix prompt errors & add per-power prompt dir functionality

This commit is contained in:
sam-paech 2025-07-04 11:31:57 +10:00
parent 22cf3202c7
commit 7edc7c465f
11 changed files with 160 additions and 112 deletions

View file

@ -25,29 +25,6 @@ logger = logging.getLogger(__name__)
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"}) ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"] ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
# == New: Helper function to load prompt files reliably ==
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
"""Loads a prompt template from the prompts directory."""
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
# Construct path relative to this file's location
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Prompt file not found: {filepath}")
return None
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
class DiplomacyAgent: class DiplomacyAgent:
""" """
Represents a stateful AI agent playing as a specific power in Diplomacy. Represents a stateful AI agent playing as a specific power in Diplomacy.
@ -452,7 +429,7 @@ class DiplomacyAgent:
try: try:
# Load the prompt template file # Load the prompt template file
prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir) prompt_template_content = load_prompt(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template_content: if not prompt_template_content:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
success_status = "Failure: Prompt file not loaded" success_status = "Failure: Prompt file not loaded"
@ -668,7 +645,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...") logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
# Load the prompt template # Load the prompt template
prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir) prompt_template = load_prompt(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
return return
@ -815,7 +792,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...") logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
# Load the template # Load the template
prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir) prompt_template = load_prompt("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
return return
@ -921,7 +898,7 @@ class DiplomacyAgent:
try: try:
# 1. Construct the prompt using the unformatted state update prompt file # 1. Construct the prompt using the unformatted state update prompt file
prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir) prompt_template = load_prompt(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.") logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
return return

View file

@ -3,7 +3,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from .utils import run_llm_and_log, log_llm_response from .utils import run_llm_and_log, log_llm_response, load_prompt
if TYPE_CHECKING: if TYPE_CHECKING:
from diplomacy import Game from diplomacy import Game
@ -12,25 +12,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation( async def run_diary_consolidation(
agent: "DiplomacyAgent", agent: "DiplomacyAgent",
game: "Game", game: "Game",
@ -75,7 +56,7 @@ async def run_diary_consolidation(
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged") logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
return return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting") logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
return return

View file

@ -5,6 +5,7 @@ import json
import asyncio import asyncio
from typing import Dict, Tuple, Optional, Any from typing import Dict, Tuple, Optional, Any
from argparse import Namespace from argparse import Namespace
from pathlib import Path
from diplomacy import Game from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
@ -150,13 +151,16 @@ def save_game_state(
# 3b. For *this* phase we also inject the fresh agent snapshot # 3b. For *this* phase we also inject the fresh agent snapshot
# and the plans written during the turn. # and the plans written during the turn.
if phase_name == completed_phase_name: if phase_name == completed_phase_name:
phase_block["config"] = vars(run_config) # ---- make run_config serialisable ---------------------------------
cfg = vars(run_config).copy()
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
if isinstance(cfg.get("prompts_dir"), Path):
cfg["prompts_dir"] = str(cfg["prompts_dir"])
# -------------------------------------------------------------------
phase_block["config"] = cfg
phase_block["state_agents"] = current_state_agents phase_block["state_agents"] = current_state_agents
# Plans for this phase may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name)
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. # # 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- # # -------------------------------------------------------------- #
@ -254,9 +258,14 @@ def load_game_state(
if "state_agents" in last_phase_data: if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...") logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
for power_name, agent_data in last_phase_data["state_agents"].items(): for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name) override_id = power_model_map.get(power_name)
prompts_dir_from_config = (
run_config.prompts_dir_map.get(power_name)
if getattr(run_config, "prompts_dir_map", None)
else run_config.prompts_dir # fallback to old single path
)
agents[power_name] = deserialize_agent( agents[power_name] = deserialize_agent(
agent_data, agent_data,
prompts_dir=prompts_dir_from_config, prompts_dir=prompts_dir_from_config,
@ -284,8 +293,15 @@ def load_game_state(
return game, agents, game_history, run_config return game, agents, game_history, run_config
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]: # ai_diplomacy/game_logic.py
"""Initializes agents for a new game.""" async def initialize_new_game(
args: Namespace,
game: Game,
game_history: GameHistory,
llm_log_file_path: str,
) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game (supports per-power prompt directories)."""
powers_order = sorted(list(ALL_POWERS)) powers_order = sorted(list(ALL_POWERS))
# Parse token limits # Parse token limits
@ -300,8 +316,7 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
else: else:
logger.warning("Expected 7 values for --max_tokens_per_model, using default.") logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
# Handle power model mapping # Handle power-model mapping
if args.models: if args.models:
provided_models = [name.strip() for name in args.models.split(",")] provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order): if len(provided_models) == len(powers_order):
@ -309,25 +324,51 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
elif len(provided_models) == 1: elif len(provided_models) == 1:
game.power_model_map = dict(zip(powers_order, provided_models * 7)) game.power_model_map = dict(zip(powers_order, provided_models * 7))
else: else:
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.") logger.error(
raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.") f"Expected {len(powers_order)} models for --models but got {len(provided_models)}."
)
raise Exception(
"Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited."
)
else: else:
game.power_model_map = assign_models_to_powers() game.power_model_map = assign_models_to_powers()
agents = {} agents: Dict[str, DiplomacyAgent] = {}
initialization_tasks = [] initialization_tasks = []
logger.info("Initializing Diplomacy Agents for each power...") logger.info("Initializing Diplomacy Agents for each power...")
for power_name, model_id in game.power_model_map.items(): for power_name, model_id in game.power_model_map.items():
if not game.powers[power_name].is_eliminated(): if not game.powers[power_name].is_eliminated():
# Determine the prompts directory for this power
if hasattr(args, "prompts_dir_map") and args.prompts_dir_map:
prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir)
else:
prompts_dir_for_power = args.prompts_dir
try: try:
client = load_model_client(model_id, prompts_dir=args.prompts_dir) client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
client.max_tokens = model_max_tokens[power_name] client.max_tokens = model_max_tokens[power_name]
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir) agent = DiplomacyAgent(
power_name=power_name,
client=client,
prompts_dir=prompts_dir_for_power,
)
agents[power_name] = agent agents[power_name] = agent
logger.info(f"Preparing initialization task for {power_name} with model {model_id}") logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir)) initialization_tasks.append(
initialize_agent_state_ext(
agent,
game,
game_history,
llm_log_file_path,
prompts_dir=prompts_dir_for_power,
)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True) logger.error(
f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
exc_info=True,
)
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...") logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True) initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
@ -342,3 +383,4 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
logger.info(f"Successfully initialized agent state for {power_name}.") logger.info(f"Successfully initialized agent state for {power_name}.")
return agents return agents

View file

@ -10,8 +10,8 @@ if False: # TYPE_CHECKING
from diplomacy.models.game import GameHistory from diplomacy.models.game import GameHistory
from .agent import DiplomacyAgent from .agent import DiplomacyAgent
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS, _load_prompt_file from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS
from .utils import run_llm_and_log, log_llm_response, get_prompt_path from .utils import run_llm_and_log, log_llm_response, get_prompt_path, load_prompt
from .prompt_constructor import build_context_prompt from .prompt_constructor import build_context_prompt
from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
@ -37,10 +37,7 @@ async def initialize_agent_state_ext(
try: try:
# Load the prompt template # Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS) allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir) initial_prompt_template = load_prompt(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return
# Format the prompt with variables # Format the prompt with variables
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str) initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)

View file

@ -12,8 +12,9 @@ import asyncio
from openai import RateLimitError, APIConnectionError, APITimeoutError from openai import RateLimitError, APIConnectionError, APITimeoutError
import aiohttp import aiohttp
import requests import requests
from pathlib import Path
from config import config from config import config
from models import POWERS_ORDER
# Avoid circular import for type hinting # Avoid circular import for type hinting
if TYPE_CHECKING: if TYPE_CHECKING:
@ -313,40 +314,34 @@ def normalize_and_compare_orders(
return orders_not_accepted, orders_not_issued return orders_not_accepted, orders_not_issued
# Helper to load prompt text from file relative to the expected 'prompts' dir def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str:
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
""" """
Return the contents of *filename* while never joining paths twice. Resolve *fname* to an absolute path and return its contents.
Resolution rules (first match wins):
Logic 1. If *fname* is absolute -> use as-is.
----- 2. If *prompts_dir* is given -> prompts_dir / fname
1. If *filename* is absolute use it directly. 3. Otherwise -> <package_root>/prompts / fname
2. Elif *filename* already contains a path component (e.g. 'x/y.txt')
treat it as a relative path and use it directly.
3. Elif *prompts_dir* is provided join prompts_dir + filename.
4. Otherwise join the packages default prompts dir.
""" """
if os.path.isabs(filename): # rule 1
prompt_path = filename fname = Path(fname)
elif os.path.dirname(filename): # rule 2 (has slash)
# If it's a relative path with directory, join with prompts_dir if provided if fname.is_absolute():
if prompts_dir: prompt_path = fname
prompt_path = os.path.join(prompts_dir, filename)
else:
if prompts_dir is not None:
prompt_path = Path(prompts_dir) / fname
else: else:
default_dir = os.path.join(os.path.dirname(__file__), "prompts") package_root = Path(__file__).resolve().parent
prompt_path = os.path.join(default_dir, filename) prompt_path = package_root / "prompts" / fname
elif prompts_dir: # rule 3
prompt_path = os.path.join(prompts_dir, filename)
else: # rule 4
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
try: try:
with open(prompt_path, "r", encoding="utf-8") as fh: return prompt_path.read_text(encoding="utf-8").strip()
return fh.read().strip()
except FileNotFoundError: except FileNotFoundError:
logger.error(f"Prompt file not found: {prompt_path}") logger.error("Prompt file not found: %s", prompt_path)
return "" raise Exception("Prompt file not found: " + prompt_path)
# == New LLM Response Logging Function == # == New LLM Response Logging Function ==
@ -552,3 +547,35 @@ def normalize_recipient_name(recipient: str) -> str:
normalized = name_mapping.get(recipient, recipient) normalized = name_mapping.get(recipient, recipient)
return normalized return normalized
def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]:
"""
Resolve --prompts_dir into a mapping {power: Path}.
Accepts either a single path or 7 comma-separated paths.
Every path is normalised to an **absolute** Path object
(using Path(...).expanduser().resolve()) and checked for existence.
"""
if not raw:
return {}
parts = [s.strip() for s in raw.split(",") if s.strip()]
if len(parts) not in {1, 7}:
raise ValueError(
f"--prompts_dir expects 1 or 7 paths, got {len(parts)} "
f"({raw})"
)
# Expand/resolve & verify
def _norm(p: str) -> Path:
path = Path(p).expanduser().resolve()
if not path.is_dir():
raise FileNotFoundError(f"Prompt directory not found: {path}")
return path
if len(parts) == 1:
path = _norm(parts[0])
return {pwr: path for pwr in POWERS_ORDER}
paths = [_norm(p) for p in parts]
return dict(zip(POWERS_ORDER, paths))

View file

@ -363,7 +363,10 @@ class StatisticalGameAnalyzer:
# Categorize by relationship # Categorize by relationship
recipient = msg.get('recipient_power') recipient = msg.get('recipient_power')
normalized_recipient = self._normalize_recipient_name(recipient) try:
normalized_recipient = PowerEnum(recipient).value
except ValueError:
normalized_recipient = None
# Skip self-messages and invalid recipients # Skip self-messages and invalid recipients
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships: if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:

View file

@ -10,7 +10,7 @@ class Configuration(BaseSettings):
DEBUG: bool = False DEBUG: bool = False
log_file_path: Path | None = None log_file_path: Path | None = None
USE_UNFORMATTED_PROMPTS: bool = False USE_UNFORMATTED_PROMPTS: bool = False
SIMPLE_PROMPTS: bool = False SIMPLE_PROMPTS: bool = True
# Default models for tasks # Default models for tasks
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20" AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"

View file

@ -20,7 +20,7 @@ os.environ["GRPC_POLL_STRATEGY"] = "poll" # Use 'poll' for macOS compatibility
from diplomacy import Game from diplomacy import Game
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders from ai_diplomacy.utils import get_valid_orders, gather_possible_orders, parse_prompts_dir_arg
from ai_diplomacy.negotiations import conduct_negotiations from ai_diplomacy.negotiations import conduct_negotiations
from ai_diplomacy.planning import planning_phase from ai_diplomacy.planning import planning_phase
from ai_diplomacy.game_history import GameHistory from ai_diplomacy.game_history import GameHistory
@ -31,6 +31,7 @@ from ai_diplomacy.game_logic import (
initialize_new_game, initialize_new_game,
) )
from ai_diplomacy.diary_logic import run_diary_consolidation from ai_diplomacy.diary_logic import run_diary_consolidation
from config import config
dotenv.load_dotenv() dotenv.load_dotenv()
@ -177,15 +178,17 @@ async def main():
args = parse_arguments() args = parse_arguments()
start_whole = time.time() start_whole = time.time()
# honour --simple_prompts before anything else needs it
if args.simple_prompts: if args.simple_prompts:
os.environ["SIMPLE_PROMPTS"] = "1" # read by prompt_constructor config.SIMPLE_PROMPTS = True
if args.prompts_dir is None: if args.prompts_dir is None:
pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy") pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy")
args.prompts_dir = os.path.join(pkg_root, "prompts_simple") args.prompts_dir = os.path.join(pkg_root, "prompts_simple")
if args.prompts_dir and not os.path.isdir(args.prompts_dir): # Prompt-dir validation & mapping
print(f"ERROR: Prompts directory not found: {args.prompts_dir}", file=sys.stderr) try:
args.prompts_dir_map = parse_prompts_dir_arg(args.prompts_dir)
except Exception as exc:
print(f"ERROR: {exc}", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Handle phase summaries flag - import narrative module only if enabled # Handle phase summaries flag - import narrative module only if enabled
@ -197,10 +200,10 @@ async def main():
# Handle unformatted prompts flag # Handle unformatted prompts flag
if args.use_unformatted_prompts: if args.use_unformatted_prompts:
os.environ["USE_UNFORMATTED_PROMPTS"] = "1" config.USE_UNFORMATTED_PROMPTS = True
logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting") logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting")
else: else:
os.environ["USE_UNFORMATTED_PROMPTS"] = "0" config.USE_UNFORMATTED_PROMPTS = False
logger.info("Using original single-step formatted prompts") logger.info("Using original single-step formatted prompts")
# --- 1. Determine Run Directory and Mode (New vs. Resume) --- # --- 1. Determine Run Directory and Mode (New vs. Resume) ---
@ -404,8 +407,10 @@ async def main():
# Diary Consolidation # Diary Consolidation
if current_short_phase.startswith("S") and current_short_phase.endswith("M"): if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
consolidation_tasks = [ consolidation_tasks = [
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir) run_diary_consolidation(agent, game, llm_log_file_path,
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated() prompts_dir=agent.prompts_dir)
for agent in agents.values()
if not game.powers[agent.power_name].is_eliminated()
] ]
if consolidation_tasks: if consolidation_tasks:
await asyncio.gather(*consolidation_tasks, return_exceptions=True) await asyncio.gather(*consolidation_tasks, return_exceptions=True)
@ -430,9 +435,14 @@ async def main():
# Save final overview stats # Save final overview stats
overview_file_path = os.path.join(run_dir, "overview.jsonl") overview_file_path = os.path.join(run_dir, "overview.jsonl")
with open(overview_file_path, "w") as overview_file: with open(overview_file_path, "w") as overview_file:
# ---- make Namespace JSON-safe ----------------------------------
cfg = vars(run_config).copy()
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
# ----------------------------------------------------------------
overview_file.write(json.dumps(model_error_stats) + "\n") overview_file.write(json.dumps(model_error_stats) + "\n")
overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n") overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n")
overview_file.write(json.dumps(vars(run_config)) + "\n") overview_file.write(json.dumps(cfg) + "\n")
logger.info("Done.") logger.info("Done.")

View file

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional, List
# your “typo → canonical” map # your “typo → canonical” map
_POWER_ALIASES = { _POWER_ALIASES = {
@ -7,8 +7,19 @@ _POWER_ALIASES = {
"GERMAN": "GERMANY", "GERMAN": "GERMANY",
"UK": "ENGLAND", "UK": "ENGLAND",
"BRIT": "ENGLAND", "BRIT": "ENGLAND",
"Germany": "GERMANY",
"England": "ENGLAND",
"France": "FRANCE",
"Italy": "ITALY",
"Russia": "RUSSIA",
"Austria": "AUSTRIA",
"Turkey": "TURKEY",
} }
POWERS_ORDER: List[str] = [
"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY",
"ITALY", "RUSSIA", "TURKEY",
]
class PowerEnum(str, Enum): class PowerEnum(str, Enum):
AUSTRIA = "AUSTRIA" AUSTRIA = "AUSTRIA"