fix prompt errors & add per-power prompt dir functionality

This commit is contained in:
sam-paech 2025-07-04 11:31:57 +10:00
parent 22cf3202c7
commit 7edc7c465f
11 changed files with 160 additions and 112 deletions

View file

@ -25,29 +25,6 @@ logger = logging.getLogger(__name__)
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
# == New: Helper function to load prompt files reliably ==
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
    """Load a prompt template from the prompts directory.

    Args:
        filename: Name (or relative path) of the prompt file.
        prompts_dir: Directory to look in. When falsy, the ``prompts``
            directory next to this module is used instead.

    Returns:
        The file contents, or ``None`` if the file could not be read
        (the failure is logged, never raised).
    """
    # Bind the name before the try block so both handlers can always
    # reference it, even when building the path itself is what fails.
    filepath = filename
    try:
        if prompts_dir:
            filepath = os.path.join(prompts_dir, filename)
        else:
            # Construct path relative to this file's location
            current_dir = os.path.dirname(os.path.abspath(__file__))
            default_prompts_dir = os.path.join(current_dir, "prompts")
            filepath = os.path.join(default_prompts_dir, filename)

        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.error(f"Prompt file not found: {filepath}")
        return None
    except Exception as e:
        logger.error(f"Error loading prompt file {filepath}: {e}")
        return None
class DiplomacyAgent:
"""
Represents a stateful AI agent playing as a specific power in Diplomacy.
@ -452,7 +429,7 @@ class DiplomacyAgent:
try:
# Load the prompt template file
prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
prompt_template_content = load_prompt(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template_content:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
success_status = "Failure: Prompt file not loaded"
@ -668,7 +645,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
# Load the prompt template
prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
prompt_template = load_prompt(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
return
@ -815,7 +792,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
# Load the template
prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
prompt_template = load_prompt("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
return
@ -921,7 +898,7 @@ class DiplomacyAgent:
try:
# 1. Construct the prompt using the unformatted state update prompt file
prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
prompt_template = load_prompt(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
return

View file

@ -3,7 +3,7 @@ import logging
import re
from typing import TYPE_CHECKING, Optional
from .utils import run_llm_and_log, log_llm_response
from .utils import run_llm_and_log, log_llm_response, load_prompt
if TYPE_CHECKING:
from diplomacy import Game
@ -12,25 +12,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation(
agent: "DiplomacyAgent",
game: "Game",
@ -75,7 +56,7 @@ async def run_diary_consolidation(
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template:
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
return

View file

@ -5,6 +5,7 @@ import json
import asyncio
from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from pathlib import Path
from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
@ -150,13 +151,16 @@ def save_game_state(
# 3b. For *this* phase we also inject the fresh agent snapshot
# and the plans written during the turn.
if phase_name == completed_phase_name:
phase_block["config"] = vars(run_config)
# ---- make run_config serialisable ---------------------------------
cfg = vars(run_config).copy()
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
if isinstance(cfg.get("prompts_dir"), Path):
cfg["prompts_dir"] = str(cfg["prompts_dir"])
# -------------------------------------------------------------------
phase_block["config"] = cfg
phase_block["state_agents"] = current_state_agents
# Plans for this phase may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name)
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
# -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- #
@ -254,9 +258,14 @@ def load_game_state(
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
prompts_dir_from_config = (
run_config.prompts_dir_map.get(power_name)
if getattr(run_config, "prompts_dir_map", None)
else run_config.prompts_dir # fallback to old single path
)
agents[power_name] = deserialize_agent(
agent_data,
prompts_dir=prompts_dir_from_config,
@ -284,8 +293,15 @@ def load_game_state(
return game, agents, game_history, run_config
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game."""
# ai_diplomacy/game_logic.py
async def initialize_new_game(
args: Namespace,
game: Game,
game_history: GameHistory,
llm_log_file_path: str,
) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game (supports per-power prompt directories)."""
powers_order = sorted(list(ALL_POWERS))
# Parse token limits
@ -300,8 +316,7 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
else:
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
# Handle power model mapping
# Handle power-model mapping
if args.models:
provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order):
@ -309,25 +324,51 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
elif len(provided_models) == 1:
game.power_model_map = dict(zip(powers_order, provided_models * 7))
else:
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.")
raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.")
logger.error(
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}."
)
raise Exception(
"Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited."
)
else:
game.power_model_map = assign_models_to_powers()
agents = {}
agents: Dict[str, DiplomacyAgent] = {}
initialization_tasks = []
logger.info("Initializing Diplomacy Agents for each power...")
for power_name, model_id in game.power_model_map.items():
if not game.powers[power_name].is_eliminated():
# Determine the prompts directory for this power
if hasattr(args, "prompts_dir_map") and args.prompts_dir_map:
prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir)
else:
prompts_dir_for_power = args.prompts_dir
try:
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
client.max_tokens = model_max_tokens[power_name]
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
agent = DiplomacyAgent(
power_name=power_name,
client=client,
prompts_dir=prompts_dir_for_power,
)
agents[power_name] = agent
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
initialization_tasks.append(
initialize_agent_state_ext(
agent,
game,
game_history,
llm_log_file_path,
prompts_dir=prompts_dir_for_power,
)
)
except Exception as e:
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
logger.error(
f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
exc_info=True,
)
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
@ -342,3 +383,4 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
logger.info(f"Successfully initialized agent state for {power_name}.")
return agents

View file

@ -10,8 +10,8 @@ if False: # TYPE_CHECKING
from diplomacy.models.game import GameHistory
from .agent import DiplomacyAgent
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS, _load_prompt_file
from .utils import run_llm_and_log, log_llm_response, get_prompt_path
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS
from .utils import run_llm_and_log, log_llm_response, get_prompt_path, load_prompt
from .prompt_constructor import build_context_prompt
from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
@ -37,10 +37,7 @@ async def initialize_agent_state_ext(
try:
# Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return
initial_prompt_template = load_prompt(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
# Format the prompt with variables
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)

View file

@ -12,8 +12,9 @@ import asyncio
from openai import RateLimitError, APIConnectionError, APITimeoutError
import aiohttp
import requests
from pathlib import Path
from config import config
from models import POWERS_ORDER
# Avoid circular import for type hinting
if TYPE_CHECKING:
@ -313,40 +314,34 @@ def normalize_and_compare_orders(
return orders_not_accepted, orders_not_issued
# Helper to load prompt text from file relative to the expected 'prompts' dir
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str:
"""
Return the contents of *filename* while never joining paths twice.
Resolve *fname* to an absolute path and return its contents.
Resolution rules (first match wins):
Logic
-----
1. If *filename* is absolute use it directly.
2. Elif *filename* already contains a path component (e.g. 'x/y.txt')
treat it as a relative path and use it directly.
3. Elif *prompts_dir* is provided join prompts_dir + filename.
4. Otherwise join the packages default prompts dir.
1. If *fname* is absolute -> use as-is.
2. If *prompts_dir* is given -> prompts_dir / fname
3. Otherwise -> <package_root>/prompts / fname
"""
if os.path.isabs(filename): # rule 1
prompt_path = filename
elif os.path.dirname(filename): # rule 2 (has slash)
# If it's a relative path with directory, join with prompts_dir if provided
if prompts_dir:
prompt_path = os.path.join(prompts_dir, filename)
fname = Path(fname)
if fname.is_absolute():
prompt_path = fname
else:
if prompts_dir is not None:
prompt_path = Path(prompts_dir) / fname
else:
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
elif prompts_dir: # rule 3
prompt_path = os.path.join(prompts_dir, filename)
else: # rule 4
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
package_root = Path(__file__).resolve().parent
prompt_path = package_root / "prompts" / fname
try:
with open(prompt_path, "r", encoding="utf-8") as fh:
return fh.read().strip()
return prompt_path.read_text(encoding="utf-8").strip()
except FileNotFoundError:
logger.error(f"Prompt file not found: {prompt_path}")
return ""
logger.error("Prompt file not found: %s", prompt_path)
raise Exception("Prompt file not found: " + prompt_path)
# == New LLM Response Logging Function ==
@ -552,3 +547,35 @@ def normalize_recipient_name(recipient: str) -> str:
normalized = name_mapping.get(recipient, recipient)
return normalized
def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]:
"""
Resolve --prompts_dir into a mapping {power: Path}.
Accepts either a single path or 7 comma-separated paths.
Every path is normalised to an **absolute** Path object
(using Path(...).expanduser().resolve()) and checked for existence.
"""
if not raw:
return {}
parts = [s.strip() for s in raw.split(",") if s.strip()]
if len(parts) not in {1, 7}:
raise ValueError(
f"--prompts_dir expects 1 or 7 paths, got {len(parts)} "
f"({raw})"
)
# Expand/resolve & verify
def _norm(p: str) -> Path:
path = Path(p).expanduser().resolve()
if not path.is_dir():
raise FileNotFoundError(f"Prompt directory not found: {path}")
return path
if len(parts) == 1:
path = _norm(parts[0])
return {pwr: path for pwr in POWERS_ORDER}
paths = [_norm(p) for p in parts]
return dict(zip(POWERS_ORDER, paths))

View file

@ -363,7 +363,10 @@ class StatisticalGameAnalyzer:
# Categorize by relationship
recipient = msg.get('recipient_power')
normalized_recipient = self._normalize_recipient_name(recipient)
try:
normalized_recipient = PowerEnum(recipient).value
except ValueError:
normalized_recipient = None
# Skip self-messages and invalid recipients
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:

View file

@ -10,7 +10,7 @@ class Configuration(BaseSettings):
DEBUG: bool = False
log_file_path: Path | None = None
USE_UNFORMATTED_PROMPTS: bool = False
SIMPLE_PROMPTS: bool = False
SIMPLE_PROMPTS: bool = True
# Default models for tasks
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"

View file

@ -20,7 +20,7 @@ os.environ["GRPC_POLL_STRATEGY"] = "poll" # Use 'poll' for macOS compatibility
from diplomacy import Game
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders, parse_prompts_dir_arg
from ai_diplomacy.negotiations import conduct_negotiations
from ai_diplomacy.planning import planning_phase
from ai_diplomacy.game_history import GameHistory
@ -31,6 +31,7 @@ from ai_diplomacy.game_logic import (
initialize_new_game,
)
from ai_diplomacy.diary_logic import run_diary_consolidation
from config import config
dotenv.load_dotenv()
@ -177,15 +178,17 @@ async def main():
args = parse_arguments()
start_whole = time.time()
# honour --simple_prompts before anything else needs it
if args.simple_prompts:
os.environ["SIMPLE_PROMPTS"] = "1" # read by prompt_constructor
config.SIMPLE_PROMPTS = True
if args.prompts_dir is None:
pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy")
args.prompts_dir = os.path.join(pkg_root, "prompts_simple")
if args.prompts_dir and not os.path.isdir(args.prompts_dir):
print(f"ERROR: Prompts directory not found: {args.prompts_dir}", file=sys.stderr)
# Prompt-dir validation & mapping
try:
args.prompts_dir_map = parse_prompts_dir_arg(args.prompts_dir)
except Exception as exc:
print(f"ERROR: {exc}", file=sys.stderr)
sys.exit(1)
# Handle phase summaries flag - import narrative module only if enabled
@ -197,10 +200,10 @@ async def main():
# Handle unformatted prompts flag
if args.use_unformatted_prompts:
os.environ["USE_UNFORMATTED_PROMPTS"] = "1"
config.USE_UNFORMATTED_PROMPTS = True
logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting")
else:
os.environ["USE_UNFORMATTED_PROMPTS"] = "0"
config.USE_UNFORMATTED_PROMPTS = False
logger.info("Using original single-step formatted prompts")
# --- 1. Determine Run Directory and Mode (New vs. Resume) ---
@ -404,8 +407,10 @@ async def main():
# Diary Consolidation
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
consolidation_tasks = [
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
run_diary_consolidation(agent, game, llm_log_file_path,
prompts_dir=agent.prompts_dir)
for agent in agents.values()
if not game.powers[agent.power_name].is_eliminated()
]
if consolidation_tasks:
await asyncio.gather(*consolidation_tasks, return_exceptions=True)
@ -430,9 +435,14 @@ async def main():
# Save final overview stats
overview_file_path = os.path.join(run_dir, "overview.jsonl")
with open(overview_file_path, "w") as overview_file:
# ---- make Namespace JSON-safe ----------------------------------
cfg = vars(run_config).copy()
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
# ----------------------------------------------------------------
overview_file.write(json.dumps(model_error_stats) + "\n")
overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n")
overview_file.write(json.dumps(vars(run_config)) + "\n")
overview_file.write(json.dumps(cfg) + "\n")
logger.info("Done.")

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, List
# your “typo → canonical” map
_POWER_ALIASES = {
@ -7,8 +7,19 @@ _POWER_ALIASES = {
"GERMAN": "GERMANY",
"UK": "ENGLAND",
"BRIT": "ENGLAND",
"Germany": "GERMANY",
"England": "ENGLAND",
"France": "FRANCE",
"Italy": "ITALY",
"Russia": "RUSSIA",
"Austria": "AUSTRIA",
"Turkey": "TURKEY",
}
# The seven powers in alphabetical order. Positional per-power inputs
# (e.g. the comma-separated paths handled by parse_prompts_dir_arg) are
# paired with this list element by element, so the order matters.
POWERS_ORDER: List[str] = [
"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY",
"ITALY", "RUSSIA", "TURKEY",
]
class PowerEnum(str, Enum):
AUSTRIA = "AUSTRIA"