From 7edc7c465f42773c8499ac96cc12a7698d74c3bb Mon Sep 17 00:00:00 2001 From: sam-paech <152407511+sam-paech@users.noreply.github.com> Date: Fri, 4 Jul 2025 11:31:57 +1000 Subject: [PATCH] fix prompt errors & add per-power prompt dir functionality --- ai_diplomacy/agent.py | 31 +------ ai_diplomacy/diary_logic.py | 23 +---- ai_diplomacy/game_logic.py | 76 +++++++++++++---- ai_diplomacy/initialization.py | 9 +- ai_diplomacy/prompts/initial_state_prompt.txt | 0 .../prompts_simple/initial_state_prompt.txt | 0 ai_diplomacy/utils.py | 83 ++++++++++++------- analysis/statistical_game_analysis.py | 5 +- config.py | 2 +- lm_game.py | 30 ++++--- models.py | 13 ++- 11 files changed, 160 insertions(+), 112 deletions(-) create mode 100644 ai_diplomacy/prompts/initial_state_prompt.txt create mode 100644 ai_diplomacy/prompts_simple/initial_state_prompt.txt diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py index 8c0d38f..5ce5d0d 100644 --- a/ai_diplomacy/agent.py +++ b/ai_diplomacy/agent.py @@ -25,29 +25,6 @@ logger = logging.getLogger(__name__) ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"}) ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"] - -# == New: Helper function to load prompt files reliably == -def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]: - """Loads a prompt template from the prompts directory.""" - try: - if prompts_dir: - filepath = os.path.join(prompts_dir, filename) - else: - # Construct path relative to this file's location - current_dir = os.path.dirname(os.path.abspath(__file__)) - default_prompts_dir = os.path.join(current_dir, "prompts") - filepath = os.path.join(default_prompts_dir, filename) - - with open(filepath, "r", encoding="utf-8") as f: - return f.read() - except FileNotFoundError: - logger.error(f"Prompt file not found: {filepath}") - return None - except Exception as e: - logger.error(f"Error loading prompt file 
{filepath}: {e}") - return None - - class DiplomacyAgent: """ Represents a stateful AI agent playing as a specific power in Diplomacy. @@ -452,7 +429,7 @@ class DiplomacyAgent: try: # Load the prompt template file - prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir) + prompt_template_content = load_prompt(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir) if not prompt_template_content: logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.") success_status = "Failure: Prompt file not loaded" @@ -668,7 +645,7 @@ class DiplomacyAgent: logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...") # Load the prompt template - prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir) + prompt_template = load_prompt(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.") return @@ -815,7 +792,7 @@ class DiplomacyAgent: logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...") # Load the template - prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir) + prompt_template = load_prompt("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.") return @@ -921,7 +898,7 @@ class DiplomacyAgent: try: # 1. 
Construct the prompt using the unformatted state update prompt file - prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir) + prompt_template = load_prompt(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir) if not prompt_template: logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.") return diff --git a/ai_diplomacy/diary_logic.py b/ai_diplomacy/diary_logic.py index bf6f757..407739d 100644 --- a/ai_diplomacy/diary_logic.py +++ b/ai_diplomacy/diary_logic.py @@ -3,7 +3,7 @@ import logging import re from typing import TYPE_CHECKING, Optional -from .utils import run_llm_and_log, log_llm_response +from .utils import run_llm_and_log, log_llm_response, load_prompt if TYPE_CHECKING: from diplomacy import Game @@ -12,25 +12,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None: - """A local copy of the helper from agent.py to avoid circular imports.""" - import os - - try: - if prompts_dir: - filepath = os.path.join(prompts_dir, filename) - else: - current_dir = os.path.dirname(os.path.abspath(__file__)) - default_prompts_dir = os.path.join(current_dir, "prompts") - filepath = os.path.join(default_prompts_dir, filename) - - with open(filepath, "r", encoding="utf-8") as f: - return f.read() - except Exception as e: - logger.error(f"Error loading prompt file {filepath}: {e}") - return None - - async def run_diary_consolidation( agent: "DiplomacyAgent", game: "Game", @@ -75,7 +56,7 @@ async def run_diary_consolidation( logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged") return - prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) + prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) if not prompt_template: 
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting") return diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py index 8a44871..5ec7277 100644 --- a/ai_diplomacy/game_logic.py +++ b/ai_diplomacy/game_logic.py @@ -5,6 +5,7 @@ import json import asyncio from typing import Dict, Tuple, Optional, Any from argparse import Namespace +from pathlib import Path from diplomacy import Game from diplomacy.utils.export import to_saved_game_format, from_saved_game_format @@ -150,13 +151,16 @@ def save_game_state( # 3b. For *this* phase we also inject the fresh agent snapshot # and the plans written during the turn. if phase_name == completed_phase_name: - phase_block["config"] = vars(run_config) + # ---- make run_config serialisable --------------------------------- + cfg = vars(run_config).copy() + if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict): + cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()} + if isinstance(cfg.get("prompts_dir"), Path): + cfg["prompts_dir"] = str(cfg["prompts_dir"]) + # ------------------------------------------------------------------- + phase_block["config"] = cfg phase_block["state_agents"] = current_state_agents - # Plans for this phase – may be empty in non-movement phases. - phase_obj = game_history._get_phase(phase_name) - phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {} - # -------------------------------------------------------------- # # 4. Attach top-level metadata and write atomically. 
# # -------------------------------------------------------------- # @@ -254,9 +258,14 @@ def load_game_state( if "state_agents" in last_phase_data: logger.info("Rebuilding agents from saved state...") - prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None + for power_name, agent_data in last_phase_data["state_agents"].items(): override_id = power_model_map.get(power_name) + prompts_dir_from_config = ( + run_config.prompts_dir_map.get(power_name) + if getattr(run_config, "prompts_dir_map", None) + else run_config.prompts_dir # fallback to old single path + ) agents[power_name] = deserialize_agent( agent_data, prompts_dir=prompts_dir_from_config, @@ -284,8 +293,15 @@ def load_game_state( return game, agents, game_history, run_config -async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]: - """Initializes agents for a new game.""" +# ai_diplomacy/game_logic.py +async def initialize_new_game( + args: Namespace, + game: Game, + game_history: GameHistory, + llm_log_file_path: str, +) -> Dict[str, DiplomacyAgent]: + """Initializes agents for a new game (supports per-power prompt directories).""" + powers_order = sorted(list(ALL_POWERS)) # Parse token limits @@ -300,8 +316,7 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis else: logger.warning("Expected 7 values for --max_tokens_per_model, using default.") - # Handle power model mapping - + # Handle power-model mapping if args.models: provided_models = [name.strip() for name in args.models.split(",")] if len(provided_models) == len(powers_order): @@ -309,25 +324,51 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis elif len(provided_models) == 1: game.power_model_map = dict(zip(powers_order, provided_models * 7)) else: - logger.error(f"Expected {len(powers_order)} models for --models but got 
{len(provided_models)}.") - raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.") + logger.error( + f"Expected {len(powers_order)} models for --models but got {len(provided_models)}." + ) + raise Exception( + "Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited." + ) else: game.power_model_map = assign_models_to_powers() - agents = {} + agents: Dict[str, DiplomacyAgent] = {} initialization_tasks = [] logger.info("Initializing Diplomacy Agents for each power...") + for power_name, model_id in game.power_model_map.items(): if not game.powers[power_name].is_eliminated(): + # Determine the prompts directory for this power + if hasattr(args, "prompts_dir_map") and args.prompts_dir_map: + prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir) + else: + prompts_dir_for_power = args.prompts_dir + try: - client = load_model_client(model_id, prompts_dir=args.prompts_dir) + client = load_model_client(model_id, prompts_dir=prompts_dir_for_power) client.max_tokens = model_max_tokens[power_name] - agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir) + agent = DiplomacyAgent( + power_name=power_name, + client=client, + prompts_dir=prompts_dir_for_power, + ) agents[power_name] = agent logger.info(f"Preparing initialization task for {power_name} with model {model_id}") - initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir)) + initialization_tasks.append( + initialize_agent_state_ext( + agent, + game, + game_history, + llm_log_file_path, + prompts_dir=prompts_dir_for_power, + ) + ) except Exception as e: - logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True) + logger.error( + f"Failed to create agent or client for {power_name} with model {model_id}: {e}", + exc_info=True, + ) 
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...") initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True) @@ -342,3 +383,4 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis logger.info(f"Successfully initialized agent state for {power_name}.") return agents + diff --git a/ai_diplomacy/initialization.py b/ai_diplomacy/initialization.py index 87b1628..0bdce77 100644 --- a/ai_diplomacy/initialization.py +++ b/ai_diplomacy/initialization.py @@ -10,8 +10,8 @@ if False: # TYPE_CHECKING from diplomacy.models.game import GameHistory from .agent import DiplomacyAgent -from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS, _load_prompt_file -from .utils import run_llm_and_log, log_llm_response, get_prompt_path +from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS +from .utils import run_llm_and_log, log_llm_response, get_prompt_path, load_prompt from .prompt_constructor import build_context_prompt from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE @@ -37,10 +37,7 @@ async def initialize_agent_state_ext( try: # Load the prompt template allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS) - initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir) - if not initial_prompt_template: - logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. 
Cannot initialize.") - return + initial_prompt_template = load_prompt(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir) # Format the prompt with variables initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str) diff --git a/ai_diplomacy/prompts/initial_state_prompt.txt b/ai_diplomacy/prompts/initial_state_prompt.txt new file mode 100644 index 0000000..e69de29 diff --git a/ai_diplomacy/prompts_simple/initial_state_prompt.txt b/ai_diplomacy/prompts_simple/initial_state_prompt.txt new file mode 100644 index 0000000..e69de29 diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index 7731f62..1337d72 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -12,8 +12,9 @@ import asyncio from openai import RateLimitError, APIConnectionError, APITimeoutError import aiohttp import requests - +from pathlib import Path from config import config +from models import POWERS_ORDER # Avoid circular import for type hinting if TYPE_CHECKING: @@ -313,40 +314,34 @@ def normalize_and_compare_orders( return orders_not_accepted, orders_not_issued -# Helper to load prompt text from file relative to the expected 'prompts' dir -def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str: +def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str: """ - Return the contents of *filename* while never joining paths twice. + Resolve *fname* to an absolute path and return its contents. + Resolution rules (first match wins): - Logic - ----- - 1. If *filename* is absolute → use it directly. - 2. Elif *filename* already contains a path component (e.g. 'x/y.txt') - → treat it as a relative path and use it directly. - 3. Elif *prompts_dir* is provided → join prompts_dir + filename. - 4. Otherwise → join the package’s default prompts dir. + 1. If *fname* is absolute -> use as-is. + 2. If *prompts_dir* is given -> prompts_dir / fname + 3. 
Otherwise -> <package_root>/prompts/fname """ - if os.path.isabs(filename): # rule 1 - prompt_path = filename - elif os.path.dirname(filename): # rule 2 (has slash) - # If it's a relative path with directory, join with prompts_dir if provided - if prompts_dir: - prompt_path = os.path.join(prompts_dir, filename) + + fname = Path(fname) + + if fname.is_absolute(): + prompt_path = fname + + else: + if prompts_dir is not None: + prompt_path = Path(prompts_dir) / fname else: - default_dir = os.path.join(os.path.dirname(__file__), "prompts") - prompt_path = os.path.join(default_dir, filename) - elif prompts_dir: # rule 3 - prompt_path = os.path.join(prompts_dir, filename) - else: # rule 4 - default_dir = os.path.join(os.path.dirname(__file__), "prompts") - prompt_path = os.path.join(default_dir, filename) + package_root = Path(__file__).resolve().parent + prompt_path = package_root / "prompts" / fname try: - with open(prompt_path, "r", encoding="utf-8") as fh: - return fh.read().strip() + return prompt_path.read_text(encoding="utf-8").strip() except FileNotFoundError: - logger.error(f"Prompt file not found: {prompt_path}") - return "" + logger.error("Prompt file not found: %s", prompt_path) + raise Exception(f"Prompt file not found: {prompt_path}") + # == New LLM Response Logging Function == @@ -552,3 +547,35 @@ def normalize_recipient_name(recipient: str) -> str: normalized = name_mapping.get(recipient, recipient) return normalized + +def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]: + """ + Resolve --prompts_dir into a mapping {power: Path}. + Accepts either a single path or 7 comma-separated paths. + + Every path is normalised to an **absolute** Path object + (using Path(...).expanduser().resolve()) and checked for existence. 
+ """ + if not raw: + return {} + + parts = [s.strip() for s in raw.split(",") if s.strip()] + if len(parts) not in {1, 7}: + raise ValueError( + f"--prompts_dir expects 1 or 7 paths, got {len(parts)} " + f"({raw})" + ) + + # Expand/resolve & verify + def _norm(p: str) -> Path: + path = Path(p).expanduser().resolve() + if not path.is_dir(): + raise FileNotFoundError(f"Prompt directory not found: {path}") + return path + + if len(parts) == 1: + path = _norm(parts[0]) + return {pwr: path for pwr in POWERS_ORDER} + + paths = [_norm(p) for p in parts] + return dict(zip(POWERS_ORDER, paths)) \ No newline at end of file diff --git a/analysis/statistical_game_analysis.py b/analysis/statistical_game_analysis.py index 32f292c..48de6b3 100644 --- a/analysis/statistical_game_analysis.py +++ b/analysis/statistical_game_analysis.py @@ -363,7 +363,10 @@ class StatisticalGameAnalyzer: # Categorize by relationship recipient = msg.get('recipient_power') - normalized_recipient = self._normalize_recipient_name(recipient) + try: + normalized_recipient = PowerEnum(recipient).value + except ValueError: + normalized_recipient = None # Skip self-messages and invalid recipients if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships: diff --git a/config.py b/config.py index 82331c9..694823f 100644 --- a/config.py +++ b/config.py @@ -10,7 +10,7 @@ class Configuration(BaseSettings): DEBUG: bool = False log_file_path: Path | None = None USE_UNFORMATTED_PROMPTS: bool = False - SIMPLE_PROMPTS: bool = False + SIMPLE_PROMPTS: bool = True # Default models for tasks AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20" diff --git a/lm_game.py b/lm_game.py index 1733941..12a3012 100644 --- a/lm_game.py +++ b/lm_game.py @@ -20,7 +20,7 @@ os.environ["GRPC_POLL_STRATEGY"] = "poll" # Use 'poll' for macOS compatibility from diplomacy import Game -from ai_diplomacy.utils import get_valid_orders, gather_possible_orders +from 
ai_diplomacy.utils import get_valid_orders, gather_possible_orders, parse_prompts_dir_arg from ai_diplomacy.negotiations import conduct_negotiations from ai_diplomacy.planning import planning_phase from ai_diplomacy.game_history import GameHistory @@ -31,6 +31,7 @@ from ai_diplomacy.game_logic import ( initialize_new_game, ) from ai_diplomacy.diary_logic import run_diary_consolidation +from config import config dotenv.load_dotenv() @@ -177,15 +178,17 @@ async def main(): args = parse_arguments() start_whole = time.time() - # honour --simple_prompts before anything else needs it if args.simple_prompts: - os.environ["SIMPLE_PROMPTS"] = "1" # read by prompt_constructor + config.SIMPLE_PROMPTS = True if args.prompts_dir is None: pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy") args.prompts_dir = os.path.join(pkg_root, "prompts_simple") - if args.prompts_dir and not os.path.isdir(args.prompts_dir): - print(f"ERROR: Prompts directory not found: {args.prompts_dir}", file=sys.stderr) + # Prompt-dir validation & mapping + try: + args.prompts_dir_map = parse_prompts_dir_arg(args.prompts_dir) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) sys.exit(1) # Handle phase summaries flag - import narrative module only if enabled @@ -197,10 +200,10 @@ async def main(): # Handle unformatted prompts flag if args.use_unformatted_prompts: - os.environ["USE_UNFORMATTED_PROMPTS"] = "1" + config.USE_UNFORMATTED_PROMPTS = True logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting") else: - os.environ["USE_UNFORMATTED_PROMPTS"] = "0" + config.USE_UNFORMATTED_PROMPTS = False logger.info("Using original single-step formatted prompts") # --- 1. Determine Run Directory and Mode (New vs. 
Resume) --- @@ -404,8 +407,10 @@ async def main(): # Diary Consolidation if current_short_phase.startswith("S") and current_short_phase.endswith("M"): consolidation_tasks = [ - run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir) - for agent in agents.values() if not game.powers[agent.power_name].is_eliminated() + run_diary_consolidation(agent, game, llm_log_file_path, + prompts_dir=agent.prompts_dir) + for agent in agents.values() + if not game.powers[agent.power_name].is_eliminated() ] if consolidation_tasks: await asyncio.gather(*consolidation_tasks, return_exceptions=True) @@ -430,9 +435,14 @@ async def main(): # Save final overview stats overview_file_path = os.path.join(run_dir, "overview.jsonl") with open(overview_file_path, "w") as overview_file: + # ---- make Namespace JSON-safe ---------------------------------- + cfg = vars(run_config).copy() + if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict): + cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()} + # ---------------------------------------------------------------- overview_file.write(json.dumps(model_error_stats) + "\n") overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n") - overview_file.write(json.dumps(vars(run_config)) + "\n") + overview_file.write(json.dumps(cfg) + "\n") logger.info("Done.") diff --git a/models.py b/models.py index 88bb881..632c040 100644 --- a/models.py +++ b/models.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, List # your “typo → canonical” map _POWER_ALIASES = { @@ -7,8 +7,19 @@ _POWER_ALIASES = { "GERMAN": "GERMANY", "UK": "ENGLAND", "BRIT": "ENGLAND", + "Germany": "GERMANY", + "England": "ENGLAND", + "France": "FRANCE", + "Italy": "ITALY", + "Russia": "RUSSIA", + "Austria": "AUSTRIA", + "Turkey": "TURKEY", } +POWERS_ORDER: List[str] = [ + "AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", + "ITALY", 
"RUSSIA", "TURKEY", +] class PowerEnum(str, Enum): AUSTRIA = "AUSTRIA"