mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
fix prompt errors & add per-power prompt dir functionality
This commit is contained in:
parent
22cf3202c7
commit
7edc7c465f
11 changed files with 160 additions and 112 deletions
|
|
@ -25,29 +25,6 @@ logger = logging.getLogger(__name__)
|
||||||
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
|
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
|
||||||
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
||||||
|
|
||||||
|
|
||||||
# == New: Helper function to load prompt files reliably ==
|
|
||||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
|
||||||
"""Loads a prompt template from the prompts directory."""
|
|
||||||
try:
|
|
||||||
if prompts_dir:
|
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
|
||||||
else:
|
|
||||||
# Construct path relative to this file's location
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
|
||||||
filepath = os.path.join(default_prompts_dir, filename)
|
|
||||||
|
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
|
||||||
return f.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error(f"Prompt file not found: {filepath}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class DiplomacyAgent:
|
class DiplomacyAgent:
|
||||||
"""
|
"""
|
||||||
Represents a stateful AI agent playing as a specific power in Diplomacy.
|
Represents a stateful AI agent playing as a specific power in Diplomacy.
|
||||||
|
|
@ -452,7 +429,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the prompt template file
|
# Load the prompt template file
|
||||||
prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
prompt_template_content = load_prompt(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template_content:
|
if not prompt_template_content:
|
||||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
|
||||||
success_status = "Failure: Prompt file not loaded"
|
success_status = "Failure: Prompt file not loaded"
|
||||||
|
|
@ -668,7 +645,7 @@ class DiplomacyAgent:
|
||||||
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the prompt template
|
# Load the prompt template
|
||||||
prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
prompt_template = load_prompt(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -815,7 +792,7 @@ class DiplomacyAgent:
|
||||||
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the template
|
# Load the template
|
||||||
prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
|
prompt_template = load_prompt("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -921,7 +898,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Construct the prompt using the unformatted state update prompt file
|
# 1. Construct the prompt using the unformatted state update prompt file
|
||||||
prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
|
prompt_template = load_prompt(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
|
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from .utils import run_llm_and_log, log_llm_response
|
from .utils import run_llm_and_log, log_llm_response, load_prompt
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
|
|
@ -12,25 +12,6 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
|
||||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
|
||||||
import os
|
|
||||||
|
|
||||||
try:
|
|
||||||
if prompts_dir:
|
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
|
||||||
else:
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
|
||||||
filepath = os.path.join(default_prompts_dir, filename)
|
|
||||||
|
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
|
||||||
return f.read()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def run_diary_consolidation(
|
async def run_diary_consolidation(
|
||||||
agent: "DiplomacyAgent",
|
agent: "DiplomacyAgent",
|
||||||
game: "Game",
|
game: "Game",
|
||||||
|
|
@ -75,7 +56,7 @@ async def run_diary_consolidation(
|
||||||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Tuple, Optional, Any
|
from typing import Dict, Tuple, Optional, Any
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
|
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
|
||||||
|
|
@ -150,13 +151,16 @@ def save_game_state(
|
||||||
# 3b. For *this* phase we also inject the fresh agent snapshot
|
# 3b. For *this* phase we also inject the fresh agent snapshot
|
||||||
# and the plans written during the turn.
|
# and the plans written during the turn.
|
||||||
if phase_name == completed_phase_name:
|
if phase_name == completed_phase_name:
|
||||||
phase_block["config"] = vars(run_config)
|
# ---- make run_config serialisable ---------------------------------
|
||||||
|
cfg = vars(run_config).copy()
|
||||||
|
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||||
|
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
||||||
|
if isinstance(cfg.get("prompts_dir"), Path):
|
||||||
|
cfg["prompts_dir"] = str(cfg["prompts_dir"])
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
phase_block["config"] = cfg
|
||||||
phase_block["state_agents"] = current_state_agents
|
phase_block["state_agents"] = current_state_agents
|
||||||
|
|
||||||
# Plans for this phase – may be empty in non-movement phases.
|
|
||||||
phase_obj = game_history._get_phase(phase_name)
|
|
||||||
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
|
||||||
|
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# 4. Attach top-level metadata and write atomically. #
|
# 4. Attach top-level metadata and write atomically. #
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
|
|
@ -254,9 +258,14 @@ def load_game_state(
|
||||||
|
|
||||||
if "state_agents" in last_phase_data:
|
if "state_agents" in last_phase_data:
|
||||||
logger.info("Rebuilding agents from saved state...")
|
logger.info("Rebuilding agents from saved state...")
|
||||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
|
||||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||||
override_id = power_model_map.get(power_name)
|
override_id = power_model_map.get(power_name)
|
||||||
|
prompts_dir_from_config = (
|
||||||
|
run_config.prompts_dir_map.get(power_name)
|
||||||
|
if getattr(run_config, "prompts_dir_map", None)
|
||||||
|
else run_config.prompts_dir # fallback to old single path
|
||||||
|
)
|
||||||
agents[power_name] = deserialize_agent(
|
agents[power_name] = deserialize_agent(
|
||||||
agent_data,
|
agent_data,
|
||||||
prompts_dir=prompts_dir_from_config,
|
prompts_dir=prompts_dir_from_config,
|
||||||
|
|
@ -284,8 +293,15 @@ def load_game_state(
|
||||||
return game, agents, game_history, run_config
|
return game, agents, game_history, run_config
|
||||||
|
|
||||||
|
|
||||||
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
# ai_diplomacy/game_logic.py
|
||||||
"""Initializes agents for a new game."""
|
async def initialize_new_game(
|
||||||
|
args: Namespace,
|
||||||
|
game: Game,
|
||||||
|
game_history: GameHistory,
|
||||||
|
llm_log_file_path: str,
|
||||||
|
) -> Dict[str, DiplomacyAgent]:
|
||||||
|
"""Initializes agents for a new game (supports per-power prompt directories)."""
|
||||||
|
|
||||||
powers_order = sorted(list(ALL_POWERS))
|
powers_order = sorted(list(ALL_POWERS))
|
||||||
|
|
||||||
# Parse token limits
|
# Parse token limits
|
||||||
|
|
@ -300,8 +316,7 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
||||||
else:
|
else:
|
||||||
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
||||||
|
|
||||||
# Handle power model mapping
|
# Handle power-model mapping
|
||||||
|
|
||||||
if args.models:
|
if args.models:
|
||||||
provided_models = [name.strip() for name in args.models.split(",")]
|
provided_models = [name.strip() for name in args.models.split(",")]
|
||||||
if len(provided_models) == len(powers_order):
|
if len(provided_models) == len(powers_order):
|
||||||
|
|
@ -309,25 +324,51 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
||||||
elif len(provided_models) == 1:
|
elif len(provided_models) == 1:
|
||||||
game.power_model_map = dict(zip(powers_order, provided_models * 7))
|
game.power_model_map = dict(zip(powers_order, provided_models * 7))
|
||||||
else:
|
else:
|
||||||
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.")
|
logger.error(
|
||||||
raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.")
|
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}."
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
"Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
game.power_model_map = assign_models_to_powers()
|
game.power_model_map = assign_models_to_powers()
|
||||||
|
|
||||||
agents = {}
|
agents: Dict[str, DiplomacyAgent] = {}
|
||||||
initialization_tasks = []
|
initialization_tasks = []
|
||||||
logger.info("Initializing Diplomacy Agents for each power...")
|
logger.info("Initializing Diplomacy Agents for each power...")
|
||||||
|
|
||||||
for power_name, model_id in game.power_model_map.items():
|
for power_name, model_id in game.power_model_map.items():
|
||||||
if not game.powers[power_name].is_eliminated():
|
if not game.powers[power_name].is_eliminated():
|
||||||
|
# Determine the prompts directory for this power
|
||||||
|
if hasattr(args, "prompts_dir_map") and args.prompts_dir_map:
|
||||||
|
prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir)
|
||||||
|
else:
|
||||||
|
prompts_dir_for_power = args.prompts_dir
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
|
client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
|
||||||
client.max_tokens = model_max_tokens[power_name]
|
client.max_tokens = model_max_tokens[power_name]
|
||||||
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
|
agent = DiplomacyAgent(
|
||||||
|
power_name=power_name,
|
||||||
|
client=client,
|
||||||
|
prompts_dir=prompts_dir_for_power,
|
||||||
|
)
|
||||||
agents[power_name] = agent
|
agents[power_name] = agent
|
||||||
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
|
initialization_tasks.append(
|
||||||
|
initialize_agent_state_ext(
|
||||||
|
agent,
|
||||||
|
game,
|
||||||
|
game_history,
|
||||||
|
llm_log_file_path,
|
||||||
|
prompts_dir=prompts_dir_for_power,
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
||||||
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
||||||
|
|
@ -342,3 +383,4 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
||||||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ if False: # TYPE_CHECKING
|
||||||
from diplomacy.models.game import GameHistory
|
from diplomacy.models.game import GameHistory
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
|
|
||||||
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS, _load_prompt_file
|
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS
|
||||||
from .utils import run_llm_and_log, log_llm_response, get_prompt_path
|
from .utils import run_llm_and_log, log_llm_response, get_prompt_path, load_prompt
|
||||||
from .prompt_constructor import build_context_prompt
|
from .prompt_constructor import build_context_prompt
|
||||||
from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
||||||
|
|
||||||
|
|
@ -37,10 +37,7 @@ async def initialize_agent_state_ext(
|
||||||
try:
|
try:
|
||||||
# Load the prompt template
|
# Load the prompt template
|
||||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||||
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
initial_prompt_template = load_prompt(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||||
if not initial_prompt_template:
|
|
||||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Format the prompt with variables
|
# Format the prompt with variables
|
||||||
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||||
|
|
|
||||||
0
ai_diplomacy/prompts/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts_simple/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts_simple/initial_state_prompt.txt
Normal file
|
|
@ -12,8 +12,9 @@ import asyncio
|
||||||
from openai import RateLimitError, APIConnectionError, APITimeoutError
|
from openai import RateLimitError, APIConnectionError, APITimeoutError
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
from config import config
|
from config import config
|
||||||
|
from models import POWERS_ORDER
|
||||||
|
|
||||||
# Avoid circular import for type hinting
|
# Avoid circular import for type hinting
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -313,40 +314,34 @@ def normalize_and_compare_orders(
|
||||||
return orders_not_accepted, orders_not_issued
|
return orders_not_accepted, orders_not_issued
|
||||||
|
|
||||||
|
|
||||||
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str:
|
||||||
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
|
||||||
"""
|
"""
|
||||||
Return the contents of *filename* while never joining paths twice.
|
Resolve *fname* to an absolute path and return its contents.
|
||||||
|
Resolution rules (first match wins):
|
||||||
|
|
||||||
Logic
|
1. If *fname* is absolute -> use as-is.
|
||||||
-----
|
2. If *prompts_dir* is given -> prompts_dir / fname
|
||||||
1. If *filename* is absolute → use it directly.
|
3. Otherwise -> <package_root>/prompts / fname
|
||||||
2. Elif *filename* already contains a path component (e.g. 'x/y.txt')
|
|
||||||
→ treat it as a relative path and use it directly.
|
|
||||||
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
|
|
||||||
4. Otherwise → join the package’s default prompts dir.
|
|
||||||
"""
|
"""
|
||||||
if os.path.isabs(filename): # rule 1
|
|
||||||
prompt_path = filename
|
fname = Path(fname)
|
||||||
elif os.path.dirname(filename): # rule 2 (has slash)
|
|
||||||
# If it's a relative path with directory, join with prompts_dir if provided
|
if fname.is_absolute():
|
||||||
if prompts_dir:
|
prompt_path = fname
|
||||||
prompt_path = os.path.join(prompts_dir, filename)
|
|
||||||
|
else:
|
||||||
|
if prompts_dir is not None:
|
||||||
|
prompt_path = Path(prompts_dir) / fname
|
||||||
else:
|
else:
|
||||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
package_root = Path(__file__).resolve().parent
|
||||||
prompt_path = os.path.join(default_dir, filename)
|
prompt_path = package_root / "prompts" / fname
|
||||||
elif prompts_dir: # rule 3
|
|
||||||
prompt_path = os.path.join(prompts_dir, filename)
|
|
||||||
else: # rule 4
|
|
||||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
|
||||||
prompt_path = os.path.join(default_dir, filename)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(prompt_path, "r", encoding="utf-8") as fh:
|
return prompt_path.read_text(encoding="utf-8").strip()
|
||||||
return fh.read().strip()
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(f"Prompt file not found: {prompt_path}")
|
logger.error("Prompt file not found: %s", prompt_path)
|
||||||
return ""
|
raise Exception("Prompt file not found: " + prompt_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# == New LLM Response Logging Function ==
|
# == New LLM Response Logging Function ==
|
||||||
|
|
@ -552,3 +547,35 @@ def normalize_recipient_name(recipient: str) -> str:
|
||||||
normalized = name_mapping.get(recipient, recipient)
|
normalized = name_mapping.get(recipient, recipient)
|
||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]:
|
||||||
|
"""
|
||||||
|
Resolve --prompts_dir into a mapping {power: Path}.
|
||||||
|
Accepts either a single path or 7 comma-separated paths.
|
||||||
|
|
||||||
|
Every path is normalised to an **absolute** Path object
|
||||||
|
(using Path(...).expanduser().resolve()) and checked for existence.
|
||||||
|
"""
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
parts = [s.strip() for s in raw.split(",") if s.strip()]
|
||||||
|
if len(parts) not in {1, 7}:
|
||||||
|
raise ValueError(
|
||||||
|
f"--prompts_dir expects 1 or 7 paths, got {len(parts)} "
|
||||||
|
f"({raw})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expand/resolve & verify
|
||||||
|
def _norm(p: str) -> Path:
|
||||||
|
path = Path(p).expanduser().resolve()
|
||||||
|
if not path.is_dir():
|
||||||
|
raise FileNotFoundError(f"Prompt directory not found: {path}")
|
||||||
|
return path
|
||||||
|
|
||||||
|
if len(parts) == 1:
|
||||||
|
path = _norm(parts[0])
|
||||||
|
return {pwr: path for pwr in POWERS_ORDER}
|
||||||
|
|
||||||
|
paths = [_norm(p) for p in parts]
|
||||||
|
return dict(zip(POWERS_ORDER, paths))
|
||||||
|
|
@ -363,7 +363,10 @@ class StatisticalGameAnalyzer:
|
||||||
|
|
||||||
# Categorize by relationship
|
# Categorize by relationship
|
||||||
recipient = msg.get('recipient_power')
|
recipient = msg.get('recipient_power')
|
||||||
normalized_recipient = self._normalize_recipient_name(recipient)
|
try:
|
||||||
|
normalized_recipient = PowerEnum(recipient).value
|
||||||
|
except ValueError:
|
||||||
|
normalized_recipient = None
|
||||||
|
|
||||||
# Skip self-messages and invalid recipients
|
# Skip self-messages and invalid recipients
|
||||||
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ class Configuration(BaseSettings):
|
||||||
DEBUG: bool = False
|
DEBUG: bool = False
|
||||||
log_file_path: Path | None = None
|
log_file_path: Path | None = None
|
||||||
USE_UNFORMATTED_PROMPTS: bool = False
|
USE_UNFORMATTED_PROMPTS: bool = False
|
||||||
SIMPLE_PROMPTS: bool = False
|
SIMPLE_PROMPTS: bool = True
|
||||||
|
|
||||||
# Default models for tasks
|
# Default models for tasks
|
||||||
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"
|
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"
|
||||||
|
|
|
||||||
30
lm_game.py
30
lm_game.py
|
|
@ -20,7 +20,7 @@ os.environ["GRPC_POLL_STRATEGY"] = "poll" # Use 'poll' for macOS compatibility
|
||||||
|
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
|
|
||||||
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders
|
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders, parse_prompts_dir_arg
|
||||||
from ai_diplomacy.negotiations import conduct_negotiations
|
from ai_diplomacy.negotiations import conduct_negotiations
|
||||||
from ai_diplomacy.planning import planning_phase
|
from ai_diplomacy.planning import planning_phase
|
||||||
from ai_diplomacy.game_history import GameHistory
|
from ai_diplomacy.game_history import GameHistory
|
||||||
|
|
@ -31,6 +31,7 @@ from ai_diplomacy.game_logic import (
|
||||||
initialize_new_game,
|
initialize_new_game,
|
||||||
)
|
)
|
||||||
from ai_diplomacy.diary_logic import run_diary_consolidation
|
from ai_diplomacy.diary_logic import run_diary_consolidation
|
||||||
|
from config import config
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
@ -177,15 +178,17 @@ async def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
start_whole = time.time()
|
start_whole = time.time()
|
||||||
|
|
||||||
# honour --simple_prompts before anything else needs it
|
|
||||||
if args.simple_prompts:
|
if args.simple_prompts:
|
||||||
os.environ["SIMPLE_PROMPTS"] = "1" # read by prompt_constructor
|
config.SIMPLE_PROMPTS = True
|
||||||
if args.prompts_dir is None:
|
if args.prompts_dir is None:
|
||||||
pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy")
|
pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy")
|
||||||
args.prompts_dir = os.path.join(pkg_root, "prompts_simple")
|
args.prompts_dir = os.path.join(pkg_root, "prompts_simple")
|
||||||
|
|
||||||
if args.prompts_dir and not os.path.isdir(args.prompts_dir):
|
# Prompt-dir validation & mapping
|
||||||
print(f"ERROR: Prompts directory not found: {args.prompts_dir}", file=sys.stderr)
|
try:
|
||||||
|
args.prompts_dir_map = parse_prompts_dir_arg(args.prompts_dir)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"ERROR: {exc}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Handle phase summaries flag - import narrative module only if enabled
|
# Handle phase summaries flag - import narrative module only if enabled
|
||||||
|
|
@ -197,10 +200,10 @@ async def main():
|
||||||
|
|
||||||
# Handle unformatted prompts flag
|
# Handle unformatted prompts flag
|
||||||
if args.use_unformatted_prompts:
|
if args.use_unformatted_prompts:
|
||||||
os.environ["USE_UNFORMATTED_PROMPTS"] = "1"
|
config.USE_UNFORMATTED_PROMPTS = True
|
||||||
logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting")
|
logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting")
|
||||||
else:
|
else:
|
||||||
os.environ["USE_UNFORMATTED_PROMPTS"] = "0"
|
config.USE_UNFORMATTED_PROMPTS = False
|
||||||
logger.info("Using original single-step formatted prompts")
|
logger.info("Using original single-step formatted prompts")
|
||||||
|
|
||||||
# --- 1. Determine Run Directory and Mode (New vs. Resume) ---
|
# --- 1. Determine Run Directory and Mode (New vs. Resume) ---
|
||||||
|
|
@ -404,8 +407,10 @@ async def main():
|
||||||
# Diary Consolidation
|
# Diary Consolidation
|
||||||
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
||||||
consolidation_tasks = [
|
consolidation_tasks = [
|
||||||
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
|
run_diary_consolidation(agent, game, llm_log_file_path,
|
||||||
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
|
prompts_dir=agent.prompts_dir)
|
||||||
|
for agent in agents.values()
|
||||||
|
if not game.powers[agent.power_name].is_eliminated()
|
||||||
]
|
]
|
||||||
if consolidation_tasks:
|
if consolidation_tasks:
|
||||||
await asyncio.gather(*consolidation_tasks, return_exceptions=True)
|
await asyncio.gather(*consolidation_tasks, return_exceptions=True)
|
||||||
|
|
@ -430,9 +435,14 @@ async def main():
|
||||||
# Save final overview stats
|
# Save final overview stats
|
||||||
overview_file_path = os.path.join(run_dir, "overview.jsonl")
|
overview_file_path = os.path.join(run_dir, "overview.jsonl")
|
||||||
with open(overview_file_path, "w") as overview_file:
|
with open(overview_file_path, "w") as overview_file:
|
||||||
|
# ---- make Namespace JSON-safe ----------------------------------
|
||||||
|
cfg = vars(run_config).copy()
|
||||||
|
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||||
|
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
||||||
|
# ----------------------------------------------------------------
|
||||||
overview_file.write(json.dumps(model_error_stats) + "\n")
|
overview_file.write(json.dumps(model_error_stats) + "\n")
|
||||||
overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n")
|
overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n")
|
||||||
overview_file.write(json.dumps(vars(run_config)) + "\n")
|
overview_file.write(json.dumps(cfg) + "\n")
|
||||||
|
|
||||||
logger.info("Done.")
|
logger.info("Done.")
|
||||||
|
|
||||||
|
|
|
||||||
13
models.py
13
models.py
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
# your “typo → canonical” map
|
# your “typo → canonical” map
|
||||||
_POWER_ALIASES = {
|
_POWER_ALIASES = {
|
||||||
|
|
@ -7,8 +7,19 @@ _POWER_ALIASES = {
|
||||||
"GERMAN": "GERMANY",
|
"GERMAN": "GERMANY",
|
||||||
"UK": "ENGLAND",
|
"UK": "ENGLAND",
|
||||||
"BRIT": "ENGLAND",
|
"BRIT": "ENGLAND",
|
||||||
|
"Germany": "GERMANY",
|
||||||
|
"England": "ENGLAND",
|
||||||
|
"France": "FRANCE",
|
||||||
|
"Italy": "ITALY",
|
||||||
|
"Russia": "RUSSIA",
|
||||||
|
"Austria": "AUSTRIA",
|
||||||
|
"Turkey": "TURKEY",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
POWERS_ORDER: List[str] = [
|
||||||
|
"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY",
|
||||||
|
"ITALY", "RUSSIA", "TURKEY",
|
||||||
|
]
|
||||||
|
|
||||||
class PowerEnum(str, Enum):
|
class PowerEnum(str, Enum):
|
||||||
AUSTRIA = "AUSTRIA"
|
AUSTRIA = "AUSTRIA"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue