mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
fix prompt errors & add per-power prompt dir functionality
This commit is contained in:
parent
22cf3202c7
commit
7edc7c465f
11 changed files with 160 additions and 112 deletions
|
|
@ -25,29 +25,6 @@ logger = logging.getLogger(__name__)
|
|||
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
|
||||
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
||||
|
||||
|
||||
# == New: Helper function to load prompt files reliably ==
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
||||
"""Loads a prompt template from the prompts directory."""
|
||||
try:
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
# Construct path relative to this file's location
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Prompt file not found: {filepath}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class DiplomacyAgent:
|
||||
"""
|
||||
Represents a stateful AI agent playing as a specific power in Diplomacy.
|
||||
|
|
@ -452,7 +429,7 @@ class DiplomacyAgent:
|
|||
|
||||
try:
|
||||
# Load the prompt template file
|
||||
prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
prompt_template_content = load_prompt(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
if not prompt_template_content:
|
||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
|
||||
success_status = "Failure: Prompt file not loaded"
|
||||
|
|
@ -668,7 +645,7 @@ class DiplomacyAgent:
|
|||
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
||||
|
||||
# Load the prompt template
|
||||
prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
prompt_template = load_prompt(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
|
||||
return
|
||||
|
|
@ -815,7 +792,7 @@ class DiplomacyAgent:
|
|||
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
||||
|
||||
# Load the template
|
||||
prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
|
||||
prompt_template = load_prompt("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
||||
return
|
||||
|
|
@ -921,7 +898,7 @@ class DiplomacyAgent:
|
|||
|
||||
try:
|
||||
# 1. Construct the prompt using the unformatted state update prompt file
|
||||
prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
prompt_template = load_prompt(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .utils import run_llm_and_log, log_llm_response
|
||||
from .utils import run_llm_and_log, log_llm_response, load_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from diplomacy import Game
|
||||
|
|
@ -12,25 +12,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||
import os
|
||||
|
||||
try:
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_diary_consolidation(
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
|
|
@ -75,7 +56,7 @@ async def run_diary_consolidation(
|
|||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||
return
|
||||
|
||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import json
|
|||
import asyncio
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from diplomacy import Game
|
||||
from diplomacy.utils.export import to_saved_game_format, from_saved_game_format
|
||||
|
|
@ -150,13 +151,16 @@ def save_game_state(
|
|||
# 3b. For *this* phase we also inject the fresh agent snapshot
|
||||
# and the plans written during the turn.
|
||||
if phase_name == completed_phase_name:
|
||||
phase_block["config"] = vars(run_config)
|
||||
# ---- make run_config serialisable ---------------------------------
|
||||
cfg = vars(run_config).copy()
|
||||
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
||||
if isinstance(cfg.get("prompts_dir"), Path):
|
||||
cfg["prompts_dir"] = str(cfg["prompts_dir"])
|
||||
# -------------------------------------------------------------------
|
||||
phase_block["config"] = cfg
|
||||
phase_block["state_agents"] = current_state_agents
|
||||
|
||||
# Plans for this phase – may be empty in non-movement phases.
|
||||
phase_obj = game_history._get_phase(phase_name)
|
||||
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 4. Attach top-level metadata and write atomically. #
|
||||
# -------------------------------------------------------------- #
|
||||
|
|
@ -254,9 +258,14 @@ def load_game_state(
|
|||
|
||||
if "state_agents" in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
||||
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir_map.get(power_name)
|
||||
if getattr(run_config, "prompts_dir_map", None)
|
||||
else run_config.prompts_dir # fallback to old single path
|
||||
)
|
||||
agents[power_name] = deserialize_agent(
|
||||
agent_data,
|
||||
prompts_dir=prompts_dir_from_config,
|
||||
|
|
@ -284,8 +293,15 @@ def load_game_state(
|
|||
return game, agents, game_history, run_config
|
||||
|
||||
|
||||
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
||||
"""Initializes agents for a new game."""
|
||||
# ai_diplomacy/game_logic.py
|
||||
async def initialize_new_game(
|
||||
args: Namespace,
|
||||
game: Game,
|
||||
game_history: GameHistory,
|
||||
llm_log_file_path: str,
|
||||
) -> Dict[str, DiplomacyAgent]:
|
||||
"""Initializes agents for a new game (supports per-power prompt directories)."""
|
||||
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
|
||||
# Parse token limits
|
||||
|
|
@ -300,8 +316,7 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
|||
else:
|
||||
logger.warning("Expected 7 values for --max_tokens_per_model, using default.")
|
||||
|
||||
# Handle power model mapping
|
||||
|
||||
# Handle power-model mapping
|
||||
if args.models:
|
||||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
if len(provided_models) == len(powers_order):
|
||||
|
|
@ -309,25 +324,51 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
|||
elif len(provided_models) == 1:
|
||||
game.power_model_map = dict(zip(powers_order, provided_models * 7))
|
||||
else:
|
||||
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}.")
|
||||
raise Exception("Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited.")
|
||||
logger.error(
|
||||
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}."
|
||||
)
|
||||
raise Exception(
|
||||
"Invalid number of models. Models list must be either exactly 1 or 7 models, comma delimited."
|
||||
)
|
||||
else:
|
||||
game.power_model_map = assign_models_to_powers()
|
||||
|
||||
agents = {}
|
||||
agents: Dict[str, DiplomacyAgent] = {}
|
||||
initialization_tasks = []
|
||||
logger.info("Initializing Diplomacy Agents for each power...")
|
||||
|
||||
for power_name, model_id in game.power_model_map.items():
|
||||
if not game.powers[power_name].is_eliminated():
|
||||
# Determine the prompts directory for this power
|
||||
if hasattr(args, "prompts_dir_map") and args.prompts_dir_map:
|
||||
prompts_dir_for_power = args.prompts_dir_map.get(power_name, args.prompts_dir)
|
||||
else:
|
||||
prompts_dir_for_power = args.prompts_dir
|
||||
|
||||
try:
|
||||
client = load_model_client(model_id, prompts_dir=args.prompts_dir)
|
||||
client = load_model_client(model_id, prompts_dir=prompts_dir_for_power)
|
||||
client.max_tokens = model_max_tokens[power_name]
|
||||
agent = DiplomacyAgent(power_name=power_name, client=client, prompts_dir=args.prompts_dir)
|
||||
agent = DiplomacyAgent(
|
||||
power_name=power_name,
|
||||
client=client,
|
||||
prompts_dir=prompts_dir_for_power,
|
||||
)
|
||||
agents[power_name] = agent
|
||||
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
|
||||
initialization_tasks.append(
|
||||
initialize_agent_state_ext(
|
||||
agent,
|
||||
game,
|
||||
game_history,
|
||||
llm_log_file_path,
|
||||
prompts_dir=prompts_dir_for_power,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to create agent or client for {power_name} with model {model_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
||||
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
||||
|
|
@ -342,3 +383,4 @@ async def initialize_new_game(args: Namespace, game: Game, game_history: GameHis
|
|||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||
|
||||
return agents
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ if False: # TYPE_CHECKING
|
|||
from diplomacy.models.game import GameHistory
|
||||
from .agent import DiplomacyAgent
|
||||
|
||||
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS, _load_prompt_file
|
||||
from .utils import run_llm_and_log, log_llm_response, get_prompt_path
|
||||
from .agent import ALL_POWERS, ALLOWED_RELATIONSHIPS
|
||||
from .utils import run_llm_and_log, log_llm_response, get_prompt_path, load_prompt
|
||||
from .prompt_constructor import build_context_prompt
|
||||
from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
||||
|
||||
|
|
@ -37,10 +37,7 @@ async def initialize_agent_state_ext(
|
|||
try:
|
||||
# Load the prompt template
|
||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||
if not initial_prompt_template:
|
||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
||||
return
|
||||
initial_prompt_template = load_prompt(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||
|
||||
# Format the prompt with variables
|
||||
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||
|
|
|
|||
0
ai_diplomacy/prompts/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts_simple/initial_state_prompt.txt
Normal file
0
ai_diplomacy/prompts_simple/initial_state_prompt.txt
Normal file
|
|
@ -12,8 +12,9 @@ import asyncio
|
|||
from openai import RateLimitError, APIConnectionError, APITimeoutError
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from pathlib import Path
|
||||
from config import config
|
||||
from models import POWERS_ORDER
|
||||
|
||||
# Avoid circular import for type hinting
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -313,40 +314,34 @@ def normalize_and_compare_orders(
|
|||
return orders_not_accepted, orders_not_issued
|
||||
|
||||
|
||||
# Helper to load prompt text from file relative to the expected 'prompts' dir
|
||||
def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||
def load_prompt(fname: str | Path, prompts_dir: str | Path | None = None) -> str:
|
||||
"""
|
||||
Return the contents of *filename* while never joining paths twice.
|
||||
Resolve *fname* to an absolute path and return its contents.
|
||||
Resolution rules (first match wins):
|
||||
|
||||
Logic
|
||||
-----
|
||||
1. If *filename* is absolute → use it directly.
|
||||
2. Elif *filename* already contains a path component (e.g. 'x/y.txt')
|
||||
→ treat it as a relative path and use it directly.
|
||||
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
|
||||
4. Otherwise → join the package’s default prompts dir.
|
||||
1. If *fname* is absolute -> use as-is.
|
||||
2. If *prompts_dir* is given -> prompts_dir / fname
|
||||
3. Otherwise -> <package_root>/prompts / fname
|
||||
"""
|
||||
if os.path.isabs(filename): # rule 1
|
||||
prompt_path = filename
|
||||
elif os.path.dirname(filename): # rule 2 (has slash)
|
||||
# If it's a relative path with directory, join with prompts_dir if provided
|
||||
if prompts_dir:
|
||||
prompt_path = os.path.join(prompts_dir, filename)
|
||||
|
||||
fname = Path(fname)
|
||||
|
||||
if fname.is_absolute():
|
||||
prompt_path = fname
|
||||
|
||||
else:
|
||||
if prompts_dir is not None:
|
||||
prompt_path = Path(prompts_dir) / fname
|
||||
else:
|
||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_path = os.path.join(default_dir, filename)
|
||||
elif prompts_dir: # rule 3
|
||||
prompt_path = os.path.join(prompts_dir, filename)
|
||||
else: # rule 4
|
||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_path = os.path.join(default_dir, filename)
|
||||
package_root = Path(__file__).resolve().parent
|
||||
prompt_path = package_root / "prompts" / fname
|
||||
|
||||
try:
|
||||
with open(prompt_path, "r", encoding="utf-8") as fh:
|
||||
return fh.read().strip()
|
||||
return prompt_path.read_text(encoding="utf-8").strip()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Prompt file not found: {prompt_path}")
|
||||
return ""
|
||||
logger.error("Prompt file not found: %s", prompt_path)
|
||||
raise Exception("Prompt file not found: " + prompt_path)
|
||||
|
||||
|
||||
|
||||
# == New LLM Response Logging Function ==
|
||||
|
|
@ -552,3 +547,35 @@ def normalize_recipient_name(recipient: str) -> str:
|
|||
normalized = name_mapping.get(recipient, recipient)
|
||||
|
||||
return normalized
|
||||
|
||||
def parse_prompts_dir_arg(raw: str | None) -> Dict[str, Path]:
|
||||
"""
|
||||
Resolve --prompts_dir into a mapping {power: Path}.
|
||||
Accepts either a single path or 7 comma-separated paths.
|
||||
|
||||
Every path is normalised to an **absolute** Path object
|
||||
(using Path(...).expanduser().resolve()) and checked for existence.
|
||||
"""
|
||||
if not raw:
|
||||
return {}
|
||||
|
||||
parts = [s.strip() for s in raw.split(",") if s.strip()]
|
||||
if len(parts) not in {1, 7}:
|
||||
raise ValueError(
|
||||
f"--prompts_dir expects 1 or 7 paths, got {len(parts)} "
|
||||
f"({raw})"
|
||||
)
|
||||
|
||||
# Expand/resolve & verify
|
||||
def _norm(p: str) -> Path:
|
||||
path = Path(p).expanduser().resolve()
|
||||
if not path.is_dir():
|
||||
raise FileNotFoundError(f"Prompt directory not found: {path}")
|
||||
return path
|
||||
|
||||
if len(parts) == 1:
|
||||
path = _norm(parts[0])
|
||||
return {pwr: path for pwr in POWERS_ORDER}
|
||||
|
||||
paths = [_norm(p) for p in parts]
|
||||
return dict(zip(POWERS_ORDER, paths))
|
||||
|
|
@ -363,7 +363,10 @@ class StatisticalGameAnalyzer:
|
|||
|
||||
# Categorize by relationship
|
||||
recipient = msg.get('recipient_power')
|
||||
normalized_recipient = self._normalize_recipient_name(recipient)
|
||||
try:
|
||||
normalized_recipient = PowerEnum(recipient).value
|
||||
except ValueError:
|
||||
normalized_recipient = None
|
||||
|
||||
# Skip self-messages and invalid recipients
|
||||
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class Configuration(BaseSettings):
|
|||
DEBUG: bool = False
|
||||
log_file_path: Path | None = None
|
||||
USE_UNFORMATTED_PROMPTS: bool = False
|
||||
SIMPLE_PROMPTS: bool = False
|
||||
SIMPLE_PROMPTS: bool = True
|
||||
|
||||
# Default models for tasks
|
||||
AI_DIPLOMACY_NARRATIVE_MODEL: str = "openrouter-google/gemini-2.5-flash-preview-05-20"
|
||||
|
|
|
|||
30
lm_game.py
30
lm_game.py
|
|
@ -20,7 +20,7 @@ os.environ["GRPC_POLL_STRATEGY"] = "poll" # Use 'poll' for macOS compatibility
|
|||
|
||||
from diplomacy import Game
|
||||
|
||||
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders
|
||||
from ai_diplomacy.utils import get_valid_orders, gather_possible_orders, parse_prompts_dir_arg
|
||||
from ai_diplomacy.negotiations import conduct_negotiations
|
||||
from ai_diplomacy.planning import planning_phase
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
|
@ -31,6 +31,7 @@ from ai_diplomacy.game_logic import (
|
|||
initialize_new_game,
|
||||
)
|
||||
from ai_diplomacy.diary_logic import run_diary_consolidation
|
||||
from config import config
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
|
@ -177,15 +178,17 @@ async def main():
|
|||
args = parse_arguments()
|
||||
start_whole = time.time()
|
||||
|
||||
# honour --simple_prompts before anything else needs it
|
||||
if args.simple_prompts:
|
||||
os.environ["SIMPLE_PROMPTS"] = "1" # read by prompt_constructor
|
||||
config.SIMPLE_PROMPTS = True
|
||||
if args.prompts_dir is None:
|
||||
pkg_root = os.path.join(os.path.dirname(__file__), "ai_diplomacy")
|
||||
args.prompts_dir = os.path.join(pkg_root, "prompts_simple")
|
||||
|
||||
if args.prompts_dir and not os.path.isdir(args.prompts_dir):
|
||||
print(f"ERROR: Prompts directory not found: {args.prompts_dir}", file=sys.stderr)
|
||||
# Prompt-dir validation & mapping
|
||||
try:
|
||||
args.prompts_dir_map = parse_prompts_dir_arg(args.prompts_dir)
|
||||
except Exception as exc:
|
||||
print(f"ERROR: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Handle phase summaries flag - import narrative module only if enabled
|
||||
|
|
@ -197,10 +200,10 @@ async def main():
|
|||
|
||||
# Handle unformatted prompts flag
|
||||
if args.use_unformatted_prompts:
|
||||
os.environ["USE_UNFORMATTED_PROMPTS"] = "1"
|
||||
config.USE_UNFORMATTED_PROMPTS = True
|
||||
logger.info("Using two-step approach: unformatted prompts + Gemini Flash formatting")
|
||||
else:
|
||||
os.environ["USE_UNFORMATTED_PROMPTS"] = "0"
|
||||
config.USE_UNFORMATTED_PROMPTS = False
|
||||
logger.info("Using original single-step formatted prompts")
|
||||
|
||||
# --- 1. Determine Run Directory and Mode (New vs. Resume) ---
|
||||
|
|
@ -404,8 +407,10 @@ async def main():
|
|||
# Diary Consolidation
|
||||
if current_short_phase.startswith("S") and current_short_phase.endswith("M"):
|
||||
consolidation_tasks = [
|
||||
run_diary_consolidation(agent, game, llm_log_file_path, prompts_dir=run_config.prompts_dir)
|
||||
for agent in agents.values() if not game.powers[agent.power_name].is_eliminated()
|
||||
run_diary_consolidation(agent, game, llm_log_file_path,
|
||||
prompts_dir=agent.prompts_dir)
|
||||
for agent in agents.values()
|
||||
if not game.powers[agent.power_name].is_eliminated()
|
||||
]
|
||||
if consolidation_tasks:
|
||||
await asyncio.gather(*consolidation_tasks, return_exceptions=True)
|
||||
|
|
@ -430,9 +435,14 @@ async def main():
|
|||
# Save final overview stats
|
||||
overview_file_path = os.path.join(run_dir, "overview.jsonl")
|
||||
with open(overview_file_path, "w") as overview_file:
|
||||
# ---- make Namespace JSON-safe ----------------------------------
|
||||
cfg = vars(run_config).copy()
|
||||
if "prompts_dir_map" in cfg and isinstance(cfg["prompts_dir_map"], dict):
|
||||
cfg["prompts_dir_map"] = {p: str(path) for p, path in cfg["prompts_dir_map"].items()}
|
||||
# ----------------------------------------------------------------
|
||||
overview_file.write(json.dumps(model_error_stats) + "\n")
|
||||
overview_file.write(json.dumps(getattr(game, 'power_model_map', {})) + "\n")
|
||||
overview_file.write(json.dumps(vars(run_config)) + "\n")
|
||||
overview_file.write(json.dumps(cfg) + "\n")
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
|
|
|||
13
models.py
13
models.py
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, List
|
||||
|
||||
# your “typo → canonical” map
|
||||
_POWER_ALIASES = {
|
||||
|
|
@ -7,8 +7,19 @@ _POWER_ALIASES = {
|
|||
"GERMAN": "GERMANY",
|
||||
"UK": "ENGLAND",
|
||||
"BRIT": "ENGLAND",
|
||||
"Germany": "GERMANY",
|
||||
"England": "ENGLAND",
|
||||
"France": "FRANCE",
|
||||
"Italy": "ITALY",
|
||||
"Russia": "RUSSIA",
|
||||
"Austria": "AUSTRIA",
|
||||
"Turkey": "TURKEY",
|
||||
}
|
||||
|
||||
POWERS_ORDER: List[str] = [
|
||||
"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY",
|
||||
"ITALY", "RUSSIA", "TURKEY",
|
||||
]
|
||||
|
||||
class PowerEnum(str, Enum):
|
||||
AUSTRIA = "AUSTRIA"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue