Ruff format and check

This commit is contained in:
Tyler Marques 2025-07-03 14:27:50 -07:00
parent aadedd6512
commit ad817695ef
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 738 additions and 973 deletions

View file

@ -10,7 +10,7 @@ import ast
from ..config import config
# Assuming BaseModelClient is importable from clients.py in the same directory
from .clients import BaseModelClient, load_model_client
from .clients import BaseModelClient
# Import load_prompt and the new logging wrapper from utils
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
@ -336,7 +336,7 @@ class DiplomacyAgent:
)
# If even the last resort doesn't give a dict, return empty dict
return {}
except Exception as e:
except Exception:
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
return {}
@ -965,7 +965,7 @@ class DiplomacyAgent:
other_powers = [p for p in game.powers if p != power_name]
# Create a readable board state string from the board_state dict
board_state_str = f"Board State:\n"
board_state_str = "Board State:\n"
for p_name, power_data in board_state.get("powers", {}).items():
# Get units and centers from the board state
units = power_data.get("units", [])

File diff suppressed because it is too large Load diff

View file

@ -11,25 +11,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts')
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation(
agent: 'DiplomacyAgent',
agent: "DiplomacyAgent",
game: "Game",
log_file_path: str,
entries_to_keep_unsummarized: int = 6,
@ -39,70 +42,42 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class.
"""
logger.info(
f"[{agent.power_name}] CONSOLIDATION START — "
f"{len(agent.full_private_diary)} total full entries"
)
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
full_entries = [
e for e in agent.full_private_diary
if not e.startswith("[CONSOLIDATED HISTORY]")
]
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary)
logger.info(
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
"skipping consolidation"
)
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
return
boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match:
logger.error(
f"[{agent.power_name}] Could not parse year from boundary entry; "
"aborting consolidation"
)
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
agent.private_diary = list(agent.full_private_diary)
return
cutoff_year = int(match.group(1))
logger.info(
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
)
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
def _entry_year(entry: str) -> int | None:
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None
entries_to_summarize = [
e for e in full_entries
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
]
entries_to_keep = [
e for e in full_entries
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
]
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
logger.info(
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
f"keeping {len(entries_to_keep)} recent entries verbatim"
)
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary)
logger.warning(
f"[{agent.power_name}] No eligible entries to summarise; "
"context diary left unchanged"
)
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template:
logger.error(
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
)
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
return
prompt = prompt_template.format(
@ -132,27 +107,18 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE"
logger.info(
f"[{agent.power_name}] Consolidation complete — "
f"{len(agent.private_diary)} context entries now"
)
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
except Exception as exc:
logger.error(
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
)
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
finally:
log_llm_response(
log_file_path=log_file_path,
model_name=(
consolidation_client.model_name
if consolidation_client is not None
else agent.client.model_name
),
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
power_name=agent.power_name,
phase=game.current_short_phase,
response_type="diary_consolidation",
raw_input_prompt=prompt,
raw_response=raw_response,
success=success_flag,
)
)

View file

@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
"""
import json
import logging
import os
from typing import Dict, Any, Optional
from typing import Optional
from pathlib import Path
# Import logging function and model configuration
from .utils import log_llm_response, get_special_models
# Import client loading function
from .clients import load_model_client
@ -26,77 +25,73 @@ FORMAT_ORDER_DIARY = "order_diary"
async def format_with_gemini_flash(
raw_response: str,
format_type: str,
power_name: Optional[str] = None,
phase: Optional[str] = None,
log_file_path: Optional[str] = None
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
) -> str:
"""
Format a natural language response into required JSON structure using Gemini Flash.
Args:
raw_response: Natural language response from primary LLM
format_type: Type of formatting required (e.g., FORMAT_ORDERS, FORMAT_STATE_UPDATE)
power_name: Optional power name for logging
phase: Optional phase for logging
log_file_path: Optional path for CSV logging
Returns:
JSON string in the expected format
"""
# Map format types to prompt files
format_prompts = {
FORMAT_STATE_UPDATE: "formatting/format_state_update.txt",
FORMAT_CONVERSATION: "formatting/format_conversation.txt",
FORMAT_CONVERSATION: "formatting/format_conversation.txt",
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
FORMAT_ORDERS: "formatting/format_orders.txt",
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
}
if format_type not in format_prompts:
raise ValueError(f"Unknown format type: {format_type}")
# Load the formatting prompt
prompt_file = Path(__file__).parent / "prompts" / format_prompts[format_type]
if not prompt_file.exists():
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
with open(prompt_file, 'r') as f:
with open(prompt_file, "r") as f:
format_prompt = f.read()
# Replace placeholder with actual response
format_prompt = format_prompt.replace("[RAW_RESPONSE]", raw_response)
# Get model name from configuration
special_models = get_special_models()
model_name = special_models["formatter"]
# Load the formatter client using the same logic as other models
formatter_client = load_model_client(model_name)
try:
logger.info(f"[FORMATTER] Calling {model_name} for {format_type} formatting")
# Create the full prompt with system message
system_content = "You are a precise formatting assistant. Extract and format information exactly as requested."
formatter_client.set_system_prompt(system_content)
# Use the client's generate_response method
formatted_response = await formatter_client.generate_response(
prompt=format_prompt,
temperature=0, # Deterministic formatting
inject_random_seed=False # No need for random seed in formatting
inject_random_seed=False, # No need for random seed in formatting
)
if not formatted_response:
logger.warning(f"[FORMATTER] {model_name} returned empty response")
return ""
# Log successful formatting
logger.info(f"[FORMATTER] Successfully formatted {format_type} response")
# Strip any markdown formatting that Gemini might add
if formatted_response.startswith("```json"):
formatted_response = formatted_response[7:]
@ -105,7 +100,7 @@ async def format_with_gemini_flash(
if formatted_response.endswith("```"):
formatted_response = formatted_response[:-3]
formatted_response = formatted_response.strip()
# Log if requested
if log_file_path:
log_llm_response(
@ -116,14 +111,14 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=formatted_response,
success="Success"
success="Success",
)
return formatted_response
except Exception as e:
logger.error(f"[FORMATTER] Error calling Gemini Flash: {e}")
# Log error if requested
if log_file_path:
log_llm_response(
@ -134,11 +129,11 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=f"ERROR: {str(e)}",
success=f"Failure: {type(e).__name__}"
success=f"Failure: {type(e).__name__}",
)
# Return empty structure based on format type
if format_type == FORMAT_CONVERSATION:
return "[]"
else:
return "{}"
return "{}"

View file

@ -3,7 +3,6 @@ import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from pprint import pformat
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
@ -24,12 +23,8 @@ class Phase:
plans: Dict[str, str] = field(default_factory=dict)
messages: List[Message] = field(default_factory=list)
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
orders_by_power: Dict[str, List[str]] = field(
default_factory=lambda: defaultdict(list)
)
results_by_power: Dict[str, List[List[str]]] = field(
default_factory=lambda: defaultdict(list)
)
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
# NEW: Store phase-end summaries provided by each power
phase_summaries: Dict[str, str] = field(default_factory=dict)
# NEW: Store experience/journal updates from each power for this phase
@ -37,11 +32,9 @@ class Phase:
def add_plan(self, power_name: str, plan: str):
self.plans[power_name] = plan
def add_message(self, sender: str, recipient: str, content: str):
self.messages.append(
Message(sender=sender, recipient=recipient, content=content)
)
self.messages.append(Message(sender=sender, recipient=recipient, content=content))
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
self.orders_by_power[power].extend(orders)
@ -110,14 +103,10 @@ class GameHistory:
phase.plans[power_name] = plan
logger.debug(f"Added plan for {power_name} in {phase_name}")
def add_message(
self, phase_name: str, sender: str, recipient: str, message_content: str
):
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
phase = self._get_phase(phase_name)
if phase:
message = Message(
sender=sender, recipient=recipient, content=message_content
)
message = Message(sender=sender, recipient=recipient, content=message_content)
phase.messages.append(message)
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
@ -147,16 +136,12 @@ class GameHistory:
phase.experience_updates[power_name] = update
logger.debug(f"Added experience update for {power_name} in {phase_name}")
def get_strategic_directives(self):
def get_strategic_directives(self):
# returns for last phase only if exists
if not self.phases:
if not self.phases:
return {}
return self.phases[-1].plans
def get_order_history_for_prompt(
self,
game: "Game",
@ -194,10 +179,7 @@ class GameHistory:
return str(tag).strip().lower()
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
eng2code = {
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
}
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
norm = game.map.norm
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
@ -209,14 +191,14 @@ class GameHistory:
for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
submitted = ph.submitted_orders_by_power.get(pwr, [])
accepted = ph.orders_by_power.get(pwr, [])
accepted = ph.orders_by_power.get(pwr, [])
if isinstance(submitted, str):
submitted = [submitted]
if isinstance(accepted, str):
accepted = [accepted]
def _norm_keep(o): # keep WAIVE readable
def _norm_keep(o): # keep WAIVE readable
return o if o.upper() == "WAIVE" else norm(o)
sub_norm = {_norm_keep(o): o for o in submitted}
@ -264,13 +246,6 @@ class GameHistory:
return "\n(No orders were issued in recent history)\n"
return "\n".join(out_lines)
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None
for phase_obj in self.phases:
@ -281,7 +256,7 @@ class GameHistory:
if not current_phase:
return f"\n(No messages found for current phase: {current_phase_name})\n"
messages_str = ""
messages_str = ""
global_msgs_content = current_phase.get_global_messages()
if global_msgs_content:
@ -299,7 +274,7 @@ class GameHistory:
messages_str += "\n"
else:
messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n (No private messages this round)\n"
if not global_msgs_content and not private_msgs_dict:
return f"\n(No messages recorded for current phase: {current_phase_name})\n"
@ -313,10 +288,10 @@ class GameHistory:
"""
if not self.phases:
return []
# Get the most recent 2 phases including current phase
recent_phases = self.phases[-2:] if len(self.phases) >= 2 else self.phases[-1:]
# Collect all messages sent TO this power
messages_to_power = []
for phase in recent_phases:
@ -325,48 +300,44 @@ class GameHistory:
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
# Skip if sender is this power (don't need to respond to own messages)
if msg.sender != power_name:
messages_to_power.append({
'sender': msg.sender,
'content': msg.content,
'phase': phase.name
})
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
# Add debug logging
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
if not messages_to_power:
logger.info(f"No messages found for {power_name} to respond to")
# Take the most recent 'limit' messages
return messages_to_power[-limit:] if messages_to_power else []
def get_ignored_messages_by_power(self, sender_name: str, num_phases: int = 3) -> Dict[str, List[Dict[str, str]]]:
"""
Identifies which powers are not responding to messages from sender_name.
Returns a dict mapping power names to their ignored messages.
A message is considered ignored if:
1. It was sent from sender_name to another power (private)
2. No response from that power was received in the same or next phase
"""
ignored_by_power = {}
# Get recent phases
recent_phases = self.phases[-num_phases:] if self.phases else []
if not recent_phases:
return ignored_by_power
for i, phase in enumerate(recent_phases):
# Get messages sent by sender to specific powers (not global)
sender_messages = []
for msg in phase.messages:
# Handle both Message objects and dict objects
if isinstance(msg, Message):
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
else: # Assume dict
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
# Check for responses in this and next phases
for msg in sender_messages:
# Handle both Message objects and dict objects
@ -374,37 +345,36 @@ class GameHistory:
recipient = msg.recipient
msg_content = msg.content
else:
recipient = msg['recipient']
msg_content = msg['content']
recipient = msg["recipient"]
msg_content = msg["content"]
# Look for responses in current phase and next phases
found_response = False
# Check remaining phases starting from current
for check_phase in recent_phases[i:min(i+2, len(recent_phases))]:
for check_phase in recent_phases[i : min(i + 2, len(recent_phases))]:
# Look for messages FROM the recipient TO the sender (direct response)
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
response_msgs = []
for m in check_phase.messages:
if isinstance(m, Message):
if m.sender == recipient and (m.recipient == sender_name or
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
if m.sender == recipient and (
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
):
response_msgs.append(m)
else: # Assume dict
if m['sender'] == recipient and (m['recipient'] == sender_name or
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
if m["sender"] == recipient and (
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
):
response_msgs.append(m)
if response_msgs:
found_response = True
break
if not found_response:
if recipient not in ignored_by_power:
ignored_by_power[recipient] = []
ignored_by_power[recipient].append({
'phase': phase.name,
'content': msg_content
})
return ignored_by_power
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
return ignored_by_power

View file

@ -3,7 +3,7 @@ import logging
import os
import json
import asyncio
from typing import Dict, List, Tuple, Optional, Any
from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from diplomacy import Game
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
# --- Serialization / Deserialization ---
def serialize_agent(agent: DiplomacyAgent) -> dict:
"""Converts an agent object to a JSON-serializable dictionary."""
return {
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary,
}
def deserialize_agent(
agent_data: dict,
prompts_dir: Optional[str] = None,
*,
override_model_id: Optional[str] = None
) -> DiplomacyAgent:
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
"""
Recreates an agent object from a dictionary.
@ -68,34 +65,31 @@ def deserialize_agent(
# --- State Management ---
# game_logic.py
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
def _next_phase_name(short: str) -> str:
"""
Return the Diplomacy phase string that chronologically follows *short*.
(E.g. S1901M S1901R, S1901R W1901A, W1901A S1902M)
"""
season = short[0] # 'S' | 'W'
year = int(short[1:5])
typ = short[-1] # 'M' | 'R' | 'A'
season = short[0] # 'S' | 'W'
year = int(short[1:5])
typ = short[-1] # 'M' | 'R' | 'A'
idx = _PHASE_ORDER.index(typ)
if idx < 2: # still in the same season
return f"{season}{year}{_PHASE_ORDER[idx+1]}"
if idx < 2: # still in the same season
return f"{season}{year}{_PHASE_ORDER[idx + 1]}"
# typ was 'A' → roll season
if season == "S": # summer → winter, same year
if season == "S": # summer → winter, same year
return f"W{year}M"
else: # winter→ spring, next year
return f"S{year+1}M"
else: # winter→ spring, next year
return f"S{year + 1}M"
def save_game_state(
game: Game,
agents: Dict[str, DiplomacyAgent],
game_history: GameHistory,
output_path: str,
run_config: Namespace,
completed_phase_name: str
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
):
"""
Serialise the entire game to JSON, preserving per-phase custom metadata
@ -130,9 +124,7 @@ def save_game_state(
if extras:
previous_phase_extras[phase["name"]] = extras
except Exception as exc:
logger.warning(
"Could not load previous save to retain metadata: %s", exc, exc_info=True
)
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
# -------------------------------------------------------------- #
# 2. Build the fresh base structure from the diplomacy library. #
@ -143,16 +135,12 @@ def save_game_state(
# 3. Walk every phase and merge the metadata back in. #
# -------------------------------------------------------------- #
# Capture the *current* snapshot of every live agent exactly once.
current_state_agents = {
p_name: serialize_agent(p_agent)
for p_name, p_agent in agents.items()
if not game.powers[p_name].is_eliminated()
}
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
for phase_block in saved_game.get("phases", []):
if int(phase_block["name"][1:5]) > run_config.max_year:
break
phase_name = phase_block["name"]
# 3a. Re-attach anything we cached from a previous save.
@ -167,31 +155,24 @@ def save_game_state(
# Plans for this phase may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name)
phase_block["state_history_plans"] = (
phase_obj.plans if phase_obj else {}
)
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
# -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- #
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
saved_game["final_agent_states"] = {
p_name: {"relationships": a.relationships, "goals": a.goals}
for p_name, a in agents.items()
}
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
# Filter out phases > max_year
#saved_game["phases"] = [
# saved_game["phases"] = [
# ph for ph in saved_game["phases"]
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
#]
# ]
atomic_write_json(saved_game, output_path)
logger.info("Game state saved successfully.")
def load_game_state(
run_dir: str,
game_file_name: str,
@ -204,7 +185,7 @@ def load_game_state(
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
logger.info(f"Loading game state from: {game_file_path}")
with open(game_file_path, 'r') as f:
with open(game_file_path, "r") as f:
saved_game_data = json.load(f)
# If resuming, find the specified phase and truncate the data after it
@ -213,49 +194,49 @@ def load_game_state(
try:
# Find the index of the phase *before* the one we want to resume from.
# We will start the simulation *at* the resume_from_phase.
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
# Truncate the list to exclude everything after the resume phase
# Note: the state saved for a given phase represents the state at the beginning of that phase.
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
# Wipe any data that must be regenerated.
for key in ("orders", "results", "messages"):
saved_game_data['phases'][-1].pop(key, None)
saved_game_data["phases"][-1].pop(key, None)
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
except StopIteration:
# If the phase is not found, maybe it's the first phase (S1901M)
if resume_from_phase == "S1901M":
saved_game_data['phases'] = []
logger.info("Resuming from S1901M. Starting with a clean history.")
saved_game_data["phases"] = []
logger.info("Resuming from S1901M. Starting with a clean history.")
else:
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
# Reconstruct the Game object
last_phase = saved_game_data['phases'][-1]
last_phase = saved_game_data["phases"][-1]
# Wipe the data that must be regenerated **but preserve the keys**
last_phase['orders'] = {} # was dict
last_phase['results'] = {} # was dict
last_phase['messages'] = []
last_phase["orders"] = {} # was dict
last_phase["results"] = {} # was dict
last_phase["messages"] = []
game = from_saved_game_format(saved_game_data)
game.phase_summaries = saved_game_data.get('phase_summaries', {})
game.phase_summaries = saved_game_data.get("phase_summaries", {})
# Reconstruct agents and game history from the *last* valid phase in the data
if not saved_game_data['phases']:
if not saved_game_data["phases"]:
# This happens if we are resuming from the very beginning (S1901M)
logger.info("No previous phases found. Initializing fresh agents and history.")
agents = {} # Will be created by the main loop
agents = {} # Will be created by the main loop
game_history = GameHistory()
else:
# We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase.
if len(saved_game_data['phases']) <= 1:
if len(saved_game_data["phases"]) <= 1:
last_phase_data = {}
else:
last_phase_data = saved_game_data['phases'][-2]
last_phase_data = saved_game_data["phases"][-2]
# -------------------- Rebuild agents -------------------- #
agents = {}
@ -269,18 +250,11 @@ def load_game_state(
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else:
raise ValueError(
f"Invalid --models argument: expected 1 or {len(powers_order)} "
f"items, got {len(provided)}."
)
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = (
run_config.prompts_dir
if run_config and hasattr(run_config, "prompts_dir")
else None
)
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
agents[power_name] = deserialize_agent(
@ -290,40 +264,30 @@ def load_game_state(
)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
raise ValueError(
"Cannot resume: 'state_agents' key not found in the last phase "
"of the save file."
)
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
# Rebuild GameHistory
game_history = GameHistory()
logger.info("Rebuilding game history...")
for phase_data in saved_game_data['phases'][:-1]:
phase_name = phase_data['name']
for phase_data in saved_game_data["phases"][:-1]:
phase_name = phase_data["name"]
game_history.add_phase(phase_name)
# Add messages
for msg in phase_data.get('messages', []):
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
for msg in phase_data.get("messages", []):
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Add plans
if 'state_history_plans' in phase_data:
for p_name, plan in phase_data['state_history_plans'].items():
if "state_history_plans" in phase_data:
for p_name, plan in phase_data["state_history_plans"].items():
game_history.add_plan(phase_name, p_name, plan)
logger.info("Game history rebuilt.")
return game, agents, game_history, run_config
async def initialize_new_game(
args: Namespace,
game: Game,
game_history: GameHistory,
llm_log_file_path: str
) -> Dict[str, DiplomacyAgent]:
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game."""
powers_order = sorted(list(ALL_POWERS))
# Parse token limits
default_max_tokens = args.max_tokens
model_max_tokens = {p: default_max_tokens for p in powers_order}
@ -339,7 +303,7 @@ async def initialize_new_game(
# Handle power model mapping
if args.models:
provided_models = [name.strip() for name in args.models.split(",")]
provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order):
game.power_model_map = dict(zip(powers_order, provided_models))
elif len(provided_models) == 1:
@ -364,17 +328,17 @@ async def initialize_new_game(
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
except Exception as e:
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
initialized_powers = list(agents.keys())
for i, result in enumerate(initialization_results):
if i < len(initialized_powers):
power_name = initialized_powers[i]
if isinstance(result, Exception):
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
else:
logger.info(f"Successfully initialized agent state for {power_name}.")
return agents
if i < len(initialized_powers):
power_name = initialized_powers[i]
if isinstance(result, Exception):
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
else:
logger.info(f"Successfully initialized agent state for {power_name}.")
return agents

View file

@ -1,12 +1,11 @@
# ai_diplomacy/initialization.py
import logging
import json
import os
from typing import Optional
from ..config import config
# Forward declaration for type hinting, actual imports in function if complex
if False: # TYPE_CHECKING
if False: # TYPE_CHECKING
from diplomacy import Game
from diplomacy.models.game import GameHistory
from .agent import DiplomacyAgent
@ -18,40 +17,40 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
logger = logging.getLogger(__name__)
async def initialize_agent_state_ext(
agent: 'DiplomacyAgent',
game: 'Game',
game_history: 'GameHistory',
agent: "DiplomacyAgent",
game: "Game",
game_history: "GameHistory",
log_file_path: str,
prompts_dir: Optional[str] = None,
):
"""Uses the LLM to set initial goals and relationships for the agent."""
power_name = agent.power_name
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)..." )
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)...")
current_phase = game.get_current_phase() if game else "UnknownPhase"
full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block
response = "" # Ensure response is defined for finally block
success_status = "Failure: Initialized" # Default status
response = "" # Ensure response is defined for finally block
success_status = "Failure: Initialized" # Default status
try:
# Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return
# Format the prompt with variables
initial_prompt = initial_prompt_template.format(
power_name=power_name,
allowed_labels_str=allowed_labels_str
)
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
board_state = game.get_state() if game else {}
possible_orders = game.get_all_possible_orders() if game else {}
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
logger.debug(
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
)
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
# For initialization, game should always be present.
@ -59,12 +58,12 @@ async def initialize_agent_state_ext(
context = build_context_prompt(
game=game,
board_state=board_state,
board_state=board_state,
power_name=power_name,
possible_orders=possible_orders,
game_history=game_history,
agent_goals=None,
agent_relationships=None,
possible_orders=possible_orders,
game_history=game_history,
agent_goals=None,
agent_relationships=None,
agent_private_diary=formatted_diary,
prompts_dir=prompts_dir,
)
@ -76,9 +75,9 @@ async def initialize_agent_state_ext(
log_file_path=log_file_path,
power_name=power_name,
phase=current_phase,
response_type='initialization', # Context for run_llm_and_log internal error logging
response_type="initialization", # Context for run_llm_and_log internal error logging
)
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
parsed_successfully = False
try:
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
if config.USE_UNFORMATTED_PROMPTS:
# Format the natural language response into JSON
formatted_response = await format_with_gemini_flash(
response,
FORMAT_INITIAL_STATE,
power_name=power_name,
phase=current_phase,
log_file_path=log_file_path
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
@ -101,20 +96,24 @@ async def initialize_agent_state_ext(
except json.JSONDecodeError as e:
logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...")
success_status = "Failure: JSONDecodeError"
update_data = {} # Ensure update_data exists for fallback logic below
parsed_successfully = False # Explicitly set here too
update_data = {} # Ensure update_data exists for fallback logic below
parsed_successfully = False # Explicitly set here too
# Fallback logic for goals/relationships will be handled later if update_data is empty
# Defensive check for update_data type if parsing was initially considered successful
if parsed_successfully:
if parsed_successfully:
if isinstance(update_data, str):
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
update_data = {} # Treat as parsing failure
logger.error(
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
)
update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: ParsedAsStr"
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
update_data = {} # Treat as parsing failure
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
logger.error(
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
)
update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: NotADict"
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
initial_relationships_applied = False
if parsed_successfully:
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
if isinstance(initial_goals, list) and initial_goals:
agent.goals = initial_goals
@ -152,30 +151,30 @@ async def initialize_agent_state_ext(
else:
logger.warning(f"[{power_name}] No valid relationships found in LLM response.")
else:
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
if initial_goals_applied or initial_relationships_applied:
success_status = "Success: Applied LLM data"
elif parsed_successfully: # Parsed but nothing useful to apply
elif parsed_successfully: # Parsed but nothing useful to apply
success_status = "Success: Parsed but no data applied"
# If not parsed_successfully, success_status is already "Failure: JSONDecodeError"
# Fallback if LLM data was not applied or parsing failed
if not initial_goals_applied:
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"]
agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.")
logger.info(f"[{power_name}] Default goals set.")
if not initial_relationships_applied:
# Check if relationships are still default-like before overriding
# Check if relationships are still default-like before overriding
is_default_relationships = True
if agent.relationships: # Check if it's not empty
if agent.relationships: # Check if it's not empty
for p in ALL_POWERS:
if p != power_name and agent.relationships.get(p) != "Neutral":
is_default_relationships = False
break
if is_default_relationships:
if is_default_relationships:
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
agent.add_journal_entry(f"[{current_phase}] Set default neutral relationships as LLM provided none valid or parse failed.")
logger.info(f"[{power_name}] Default neutral relationships set.")
@ -191,16 +190,16 @@ async def initialize_agent_state_ext(
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}")
finally:
if log_file_path: # Ensure log_file_path is provided
if log_file_path: # Ensure log_file_path is provided
log_llm_response(
log_file_path=log_file_path,
model_name=agent.client.model_name if agent and agent.client else "UnknownModel",
power_name=power_name,
phase=current_phase,
response_type="initial_state_setup", # Specific type for CSV logging
response_type="initial_state_setup", # Specific type for CSV logging
raw_input_prompt=full_prompt,
raw_response=response,
success=success_status
success=success_status,
)
# Final log of state after initialization attempt

View file

@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")

View file

@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
from diplomacy.engine.message import Message, GLOBAL
from .agent import DiplomacyAgent
from .clients import load_model_client
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
from .utils import gather_possible_orders, normalize_recipient_name
if TYPE_CHECKING:
from .game_history import GameHistory
@ -21,9 +20,9 @@ load_dotenv()
async def conduct_negotiations(
game: 'Game',
game: "Game",
agents: Dict[str, DiplomacyAgent],
game_history: 'GameHistory',
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
max_rounds: int = 3,
@ -35,13 +34,9 @@ async def conduct_negotiations(
"""
logger.info("Starting negotiation phase.")
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
logger.info(f"Active powers for negotiations: {active_powers}")
if eliminated_powers:
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
@ -51,7 +46,7 @@ async def conduct_negotiations(
# We do up to 'max_rounds' single-message turns for each power
for round_index in range(max_rounds):
logger.info(f"Negotiation Round {round_index + 1}/{max_rounds}")
# Prepare tasks for asyncio.gather
tasks = []
power_names_for_tasks = []
@ -99,28 +94,28 @@ async def conduct_negotiations(
# Process results
for i, result in enumerate(results):
power_name = power_names_for_tasks[i]
agent = agents[power_name] # Get agent again for journaling
model_name = agent.client.model_name # Get model name for stats
agent = agents[power_name] # Get agent again for journaling
model_name = agent.client.model_name # Get model name for stats
if isinstance(result, Exception):
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
# Use model_name for stats key if possible
if model_name in model_error_stats:
model_error_stats[model_name]["conversation_errors"] += 1
else: # Fallback to power_name if model name not tracked (shouldn't happen)
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
model_error_stats[power_name]["conversation_errors"] += 1
messages = [] # Treat as no messages on error
elif result is None: # Handle case where client might return None on internal error
logger.warning(f"Received None instead of messages for {power_name}.")
messages = []
if model_name in model_error_stats:
model_error_stats[model_name]["conversation_errors"] += 1
else:
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
model_error_stats[power_name]["conversation_errors"] += 1
model_error_stats[model_name]["conversation_errors"] += 1
else: # Fallback to power_name if model name not tracked (shouldn't happen)
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
model_error_stats[power_name]["conversation_errors"] += 1
messages = [] # Treat as no messages on error
elif result is None: # Handle case where client might return None on internal error
logger.warning(f"Received None instead of messages for {power_name}.")
messages = []
if model_name in model_error_stats:
model_error_stats[model_name]["conversation_errors"] += 1
else:
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
model_error_stats[power_name]["conversation_errors"] += 1
else:
messages = result # result is the list of message dicts
messages = result # result is the list of message dicts
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
# Process the received messages (same logic as before)
@ -134,27 +129,27 @@ async def conduct_negotiations(
# Create an official message in the Diplomacy engine
# Determine recipient based on message type
if message.get("message_type") == "private":
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
if recipient not in game.powers and recipient != GLOBAL:
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
else: # Assume global if not private or type is missing
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
else: # Assume global if not private or type is missing
recipient = GLOBAL
diplo_message = Message(
phase=game.current_short_phase,
sender=power_name,
recipient=recipient, # Use determined recipient
message=message.get("content", ""), # Use .get for safety
time_sent=None, # Let the engine assign time
recipient=recipient, # Use determined recipient
message=message.get("content", ""), # Use .get for safety
time_sent=None, # Let the engine assign time
)
game.add_message(diplo_message)
# Also add to our custom history
game_history.add_message(
game.current_short_phase,
power_name,
recipient, # Use determined recipient here too
message.get("content", ""), # Use .get for safety
recipient, # Use determined recipient here too
message.get("content", ""), # Use .get for safety
)
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")

View file

@ -1,18 +1,17 @@
from dotenv import load_dotenv
import logging
import concurrent.futures
from typing import Dict
from .clients import load_model_client
from .game_history import GameHistory
from .agent import DiplomacyAgent
logger = logging.getLogger(__name__)
async def planning_phase(
game,
agents: Dict[str, DiplomacyAgent],
game_history: GameHistory,
game,
agents: Dict[str, DiplomacyAgent],
game_history: GameHistory,
model_error_stats,
log_file_path: str,
):
@ -20,24 +19,18 @@ async def planning_phase(
Lets each power generate a strategic plan using their DiplomacyAgent.
"""
logger.info(f"Starting planning phase for {game.current_short_phase}...")
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
logger.info(f"Active powers for planning: {active_powers}")
if eliminated_powers:
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
else:
logger.info("No eliminated powers yet.")
board_state = game.get_state()
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(active_powers)
) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
futures = {}
for power_name in active_powers:
if power_name not in agents:
@ -45,7 +38,7 @@ async def planning_phase(
continue
agent = agents[power_name]
client = agent.client
future = executor.submit(
client.get_plan,
game,
@ -67,20 +60,18 @@ async def planning_phase(
try:
plan_result = future.result()
logger.info(f"Received planning result from {power_name}.")
if plan_result.startswith("Error:"):
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
model_error_stats[power_name]['planning_generation_errors'] += 1
else:
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
if power_name in model_error_stats:
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
model_error_stats[power_name]["planning_generation_errors"] += 1
else:
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
elif plan_result:
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
game_history.add_plan(
game.current_short_phase, power_name, plan_result
)
game_history.add_plan(game.current_short_phase, power_name, plan_result)
logger.debug(f"Added plan for {power_name} to history.")
else:
logger.warning(f"Agent {power_name} returned an empty plan.")
@ -88,11 +79,11 @@ async def planning_phase(
except Exception as e:
logger.error(f"Exception during planning result processing for {power_name}: {e}")
if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
model_error_stats[power_name]['planning_execution_errors'] += 1
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
model_error_stats[power_name]["planning_execution_errors"] += 1
else:
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
logger.info("Planning phase processing complete.")
return game_history
return game_history

View file

@ -8,12 +8,13 @@ import logging
import re
# Placeholder for actual map type from diplomacy.engine.map.Map
# GameMap = Any
# GameMap = Any
# Type hint for board_state dictionary from game.get_state()
# BoardState = Dict[str, Any]
logger = logging.getLogger(__name__)
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
"""
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
@ -22,20 +23,18 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
"""
# ── collect all 3-letter province codes ───────────────────────────────
provs: Set[str] = {
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
for loc in game_map.locs
if len(loc.split("/")[0]) == 3
}
graph: Dict[str, Dict[str, List[str]]] = {
p: {"ARMY": [], "FLEET": []} for p in provs
}
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
# ── helper: list every concrete variant of a province ─────────────────
def variants(code: str) -> List[str]:
lst = list(game_map.loc_coasts.get(code, []))
if code not in lst:
lst.append(code) # ensure base node included
lst.append(code) # ensure base node included
return lst
# ── populate adjacency by brute-force queries to `abuts()` ────────────
@ -55,48 +54,43 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
graph[src]["ARMY"].append(dest)
# FLEETany src variant that can host a fleet is valid
if any(
game_map.abuts("F", sv, "-", dv)
for sv in src_vers
for dv in dest_vers
):
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
graph[src]["FLEET"].append(dest)
# ── tidy up duplicates / order ---------------------------------------
for p in graph:
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"]))
return graph
def bfs_shortest_path(
graph: Dict[str, Dict[str, List[str]]],
board_state: BoardState,
game_map: GameMap, # Added game_map
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
unit_type: str,
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
) -> Optional[List[str]]: # Returns path of SHORT names
graph: Dict[str, Dict[str, List[str]]],
board_state: BoardState,
game_map: GameMap, # Added game_map
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
unit_type: str,
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
) -> Optional[List[str]]: # Returns path of SHORT names
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
# Convert full start location to short province name
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
start_loc_short = start_loc_short[:3]
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
# A simpler way for non-coastal full (like 'VIE') or already short:
if '/' not in start_loc_full:
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
else: # Has '/', e.g. 'STP/SC'
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
if "/" not in start_loc_full:
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
else: # Has '/', e.g. 'STP/SC'
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
if start_loc_short not in graph:
logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.")
return None
queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
visited_nodes: Set[str] = {start_loc_short}
while queue:
@ -104,14 +98,14 @@ def bfs_shortest_path(
# is_target_func expects a short location name
if is_target_func(current_loc_short, board_state):
return path # Path of short names
return path # Path of short names
# possible_neighbors are SHORT names from the graph
possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, [])
for next_loc_short in possible_neighbors_short:
if next_loc_short not in visited_nodes:
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.")
continue
visited_nodes.add(next_loc_short)
@ -119,11 +113,12 @@ def bfs_shortest_path(
queue.append((next_loc_short, new_path))
return None
# --- Helper functions for context generation ---
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
for power, unit_list in board_state.get('units', {}).items():
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
for power, unit_list in board_state.get("units", {}).items():
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
parts = unit_str.split(" ")
if len(parts) == 2:
unit_map_loc = parts[1]
@ -131,27 +126,29 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
return f"{parts[0]} {location} ({power})"
return None
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
"""Returns the controlling power's name if the location is an SC, else None."""
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
if loc_province_name not in game_map.scs:
return None
for power, sc_list in board_state.get('centers', {}).items():
for power, sc_list in board_state.get("centers", {}).items():
if loc_province_name in sc_list:
return power
return None # Unowned SC
return None # Unowned SC
def get_shortest_path_to_friendly_unit(
board_state: BoardState,
board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]],
game_map: GameMap, # Added game_map
power_name: str,
start_unit_loc_full: str,
start_unit_type: str
game_map: GameMap, # Added game_map
power_name: str,
start_unit_loc_full: str,
start_unit_type: str,
) -> Optional[Tuple[str, List[str]]]:
"""Finds the shortest path to any friendly unit of the same power."""
def is_target_friendly(loc_short: str, current_board_state: BoardState) -> bool:
# loc_short is a short province name. Need to check all its full locations.
full_locs_for_short = game_map.loc_coasts.get(loc_short, [loc_short])
@ -162,7 +159,7 @@ def get_shortest_path_to_friendly_unit(
return False
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly)
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
target_loc_short = path_short_names[-1]
# Find the actual friendly unit string at one of the full locations of target_loc_short
friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT"
@ -177,41 +174,41 @@ def get_shortest_path_to_friendly_unit(
def get_nearest_enemy_units(
board_state: BoardState,
board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]],
game_map: GameMap, # Added game_map
power_name: str,
start_unit_loc_full: str,
start_unit_type: str,
n: int = 3
game_map: GameMap, # Added game_map
power_name: str,
start_unit_loc_full: str,
start_unit_type: str,
n: int = 3,
) -> List[Tuple[str, List[str]]]:
"""Finds up to N nearest enemy units, sorted by path length."""
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
all_enemy_unit_locations_full: List[Tuple[str,str]] = [] # (loc_full, unit_str_full)
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
all_enemy_unit_locations_full: List[Tuple[str, str]] = [] # (loc_full, unit_str_full)
# board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... }
for p_name, unit_list_for_power in board_state.get("units", {}).items():
if p_name != power_name: # If it's an enemy power
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
if p_name != power_name: # If it's an enemy power
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
parts = unit_repr_from_state.split(" ")
if len(parts) == 2:
# unit_type_char = parts[0] # 'A' or 'F'
loc_full = parts[1] # 'PAR' or 'STP/SC'
loc_full = parts[1] # 'PAR' or 'STP/SC'
# Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)"
full_unit_str_with_power = get_unit_at_location(board_state, loc_full)
if full_unit_str_with_power: # Should find the unit if iteration is correct
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
if full_unit_str_with_power: # Should find the unit if iteration is correct
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
if '/' in target_enemy_loc_short:
if "/" in target_enemy_loc_short:
target_enemy_loc_short = target_enemy_loc_short[:3]
if '/' not in target_enemy_loc_full:
if "/" not in target_enemy_loc_full:
target_enemy_loc_short = target_enemy_loc_full[:3]
else:
target_enemy_loc_short = target_enemy_loc_full[:3]
def is_specific_enemy_loc(loc_short: str, current_board_state: BoardState) -> bool:
# Check if loc_short corresponds to target_enemy_loc_full
return loc_short == target_enemy_loc_short
@ -219,8 +216,8 @@ def get_nearest_enemy_units(
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_specific_enemy_loc)
if path_short_names:
enemy_paths.append((enemy_unit_str, path_short_names))
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
return enemy_paths[:n]
@ -242,10 +239,10 @@ def get_nearest_uncontrolled_scs(
"""
results: List[Tuple[str, int, List[str]]] = []
for sc_short in game_map.scs: # all SC province codes
for sc_short in game_map.scs: # all SC province codes
controller = get_sc_controller(game_map, board_state, sc_short)
if controller == power_name:
continue # already ours
continue # already ours
# helper for BFS target test
def is_target(loc_short: str, _state: BoardState) -> bool:
@ -260,9 +257,9 @@ def get_nearest_uncontrolled_scs(
is_target,
)
if not path:
continue # unreachable
continue # unreachable
distance = len(path) - 1 # moves needed
distance = len(path) - 1 # moves needed
# skip distance 0 (same province) and 1 (adjacent)
if distance <= 1:
@ -277,11 +274,11 @@ def get_nearest_uncontrolled_scs(
def get_adjacent_territory_details(
game_map: GameMap,
board_state: BoardState,
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
graph: Dict[str, Dict[str, List[str]]]
game_map: GameMap,
board_state: BoardState,
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
graph: Dict[str, Dict[str, List[str]]],
) -> str:
"""Generates a string describing adjacent territories and units that can interact with them."""
output_lines: List[str] = []
@ -291,37 +288,37 @@ def get_adjacent_territory_details(
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
if '/' in unit_loc_short:
if "/" in unit_loc_short:
unit_loc_short = unit_loc_short[:3]
if '/' not in unit_loc_full:
if "/" not in unit_loc_full:
unit_loc_short = unit_loc_full[:3]
else:
unit_loc_short = unit_loc_full[:3]
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
# adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
continue
processed_adj_provinces.add(adj_loc_short)
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
elif adj_loc_type == 'WATER':
adj_loc_type_display = 'WATER'
else: # SHUT etc.
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
elif adj_loc_type == "WATER":
adj_loc_type_display = "WATER"
else: # SHUT etc.
adj_loc_type_display = adj_loc_type
line = f" {adj_loc_short} ({adj_loc_type_display})"
sc_controller = get_sc_controller(game_map, board_state, adj_loc_short)
if sc_controller:
line += f" SC Control: {sc_controller}"
unit_in_adj_loc = get_unit_at_location(board_state, adj_loc_short)
if unit_in_adj_loc:
line += f" Units: {unit_in_adj_loc}"
@ -330,15 +327,14 @@ def get_adjacent_territory_details(
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
# A true "can support/move to" would require checking possible orders of those further units.
# further_adj_provinces are short names from the graph
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
graph.get(adj_loc_short, {}).get('FLEET', [])
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
supporting_units_info = []
processed_further_provinces = set()
for further_adj_loc_short in further_adj_provinces_short:
# further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3]
# No conversion needed, it's already short
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
continue
if further_adj_loc_short in processed_further_provinces:
continue
@ -354,14 +350,14 @@ def get_adjacent_territory_details(
temp_unit = get_unit_at_location(board_state, fv_further)
if temp_unit:
unit_in_further_loc = temp_unit
break # Found a unit in one of the coasts/base
break # Found a unit in one of the coasts/base
# if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short:
# unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short)
if unit_in_further_loc:
supporting_units_info.append(unit_in_further_loc)
if supporting_units_info:
output_lines.append(f" => Can support/move to: {', '.join(sorted(list(set(supporting_units_info))))}")
@ -378,30 +374,30 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
board_state: BoardState = game.get_state()
game_map: GameMap = game.map
graph = build_diplomacy_graph(game_map)
final_context_lines: List[str] = ["<PossibleOrdersContext>"]
# Iterate through units that have orders (keys of possible_orders_for_power are unit locations)
for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items():
unit_str_full = get_unit_at_location(board_state, unit_loc_full)
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
continue
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
continue
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
if loc_type_short == "COAST" or loc_type_short == "LAND":
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
else:
loc_type_display = loc_type_short
current_unit_lines: List[str] = []
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
# Unit Information section
current_unit_lines.append(' <UnitInformation>')
current_unit_lines.append(" <UnitInformation>")
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
if sc_owner_at_loc == power_name:
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
header_content += f" (SC controlled by {sc_owner_at_loc})"
current_unit_lines.append(f" {header_content}")
current_unit_lines.append(f" Units present: {unit_str_full}")
current_unit_lines.append(' </UnitInformation>')
current_unit_lines.append(" </UnitInformation>")
# Possible moves section
current_unit_lines.append(' <PossibleMoves>')
current_unit_lines.append(" <PossibleMoves>")
current_unit_lines.append(" Possible moves:")
for order_str in unit_specific_possible_orders:
current_unit_lines.append(f" {order_str}")
current_unit_lines.append(' </PossibleMoves>')
current_unit_lines.append(" </PossibleMoves>")
# Nearest enemy units section
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestEnemyUnits>')
current_unit_lines.append(" <NearestEnemyUnits>")
if enemy_units_info:
current_unit_lines.append(" Nearest units (not ours):")
for enemy_unit_str, enemy_path_short in enemy_units_info:
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
current_unit_lines.append(
f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
)
else:
current_unit_lines.append(" Nearest units (not ours): None found")
current_unit_lines.append(' </NearestEnemyUnits>')
current_unit_lines.append(" </NearestEnemyUnits>")
# Nearest supply centers (not controlled by us) section
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
if uncontrolled_scs_info:
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
current_unit_lines.append(
f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
)
else:
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
# Adjacent territories details section
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
current_unit_lines.append(' <AdjacentTerritories>')
current_unit_lines.append(" <AdjacentTerritories>")
if adj_details_str:
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
# Assuming adj_details_str is already formatted with newlines and indentation for its content
# We might need to indent adj_details_str if it's a single block of text
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
if '\n' in adj_details_str:
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
if "\n" in adj_details_str:
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
current_unit_lines.append(indented_adj_details)
else:
current_unit_lines.append(f" {adj_details_str}")
current_unit_lines.append(f" {adj_details_str}")
else:
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
current_unit_lines.append(' </AdjacentTerritories>')
current_unit_lines.append(' </UnitContext>')
current_unit_lines.append(
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
) # Added more descriptive else
current_unit_lines.append(" </AdjacentTerritories>")
current_unit_lines.append(" </UnitContext>")
final_context_lines.extend(current_unit_lines)
final_context_lines.append("</PossibleOrdersContext>")
@ -469,16 +471,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
# Regex and tiny helpers
# ---------------------------------------------------------------------------
import re
from typing import Tuple, List, Dict, Optional, Any
# ── order-syntax matchers ─────────────────────────────────────────────────
_SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$")
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
def _is_hold_order(order: str) -> bool: # NEW
def _is_hold_order(order: str) -> bool: # NEW
return bool(_HOLD_RE.match(order.strip()))
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
"""Trim & uppercase for reliable comparisons."""
return name.strip().upper()
def _is_simple_move(order: str) -> bool:
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
def _is_retreat_order(order: str) -> bool:
return bool(_RETREAT_RE.match(order.strip()))
def _is_adjust_order(order: str) -> bool:
return bool(_ADJUST_RE.match(order.strip()))
def _split_move(order: str) -> Tuple[str, str]:
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
unit_part, dest = order.split(" - ")
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
# Gather *all* friendly support orders for a given move
# ---------------------------------------------------------------------------
def _all_support_examples(
mover: str,
dest: str,
@ -526,6 +533,7 @@ def _all_support_examples(
return supports
def _all_support_hold_examples(
holder: str,
all_orders: Dict[str, List[str]],
@ -538,7 +546,7 @@ def _all_support_hold_examples(
supports: List[str] = []
for loc, orders in all_orders.items():
if holder.endswith(loc): # skip the holding unit
if holder.endswith(loc): # skip the holding unit
continue
for o in orders:
if o.strip().endswith(target):
@ -550,6 +558,7 @@ def _all_support_hold_examples(
# Province-type resolver (handles short codes, coasts, seas)
# ---------------------------------------------------------------------------
def _province_type_display(game_map, prov_short: str) -> str:
"""
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
@ -575,7 +584,7 @@ def _dest_occupancy_desc(
board_state,
our_power: str,
) -> str:
""" '(occupied by X)', '(occupied by X — you!)', or '(unoccupied)' """
"""'(occupied by X)', '(occupied by X — you!)', or '(unoccupied)'"""
occupant: Optional[str] = None
for full in game_map.loc_coasts.get(dest_short, [dest_short]):
u = get_unit_at_location(board_state, full)
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
# Adjacent-territory lines (used by movement-phase builder)
# ---------------------------------------------------------------------------
def _adjacent_territory_lines(
graph,
game_map,
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
pwr = unit_here.split(" ")[-1].strip("()")
if pwr == our_power:
friend_descr = unit_here.split(" (")[0]
lines.append(
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
)
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
else:
lines.append(
f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa"
)
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa")
return lines
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
# Movement-phase generator (UNCHANGED LOGIC)
# ---------------------------------------------------------------------------
def _generate_rich_order_context_movement(
game: Any,
power_name: str,
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
if _norm_power(unit_power) != me:
continue # Skip units that arent ours
mover_descr, _ = _split_move(
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
)
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
prov_type_disp = _province_type_display(game_map, prov_short)
sc_tag = " (SC)" if prov_short in game_map.scs else ""
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
owner_line = (
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
)
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
ind = " "
block: List[str] = [f"<Territory {prov_short}>"]
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
# ----- Adjacent territories -----
block.append("# Adjacent territories:")
block.extend(
_adjacent_territory_lines(
graph, game_map, board_state,
unit_loc_full, mover_descr, power_name
)
)
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
# ----- Nearest enemy units -----
block.append("# Nearest units (not ours):")
enemies = get_nearest_enemy_units(
board_state, graph, game_map,
power_name, unit_loc_full,
board_state,
graph,
game_map,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
# ----- Nearest uncontrolled SCs -----
block.append("# Nearest supply centers (not controlled by us):")
scs = get_nearest_uncontrolled_scs(
game_map, board_state, graph,
power_name, unit_loc_full,
game_map,
board_state,
graph,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@ -729,7 +733,7 @@ def _generate_rich_order_context_movement(
block.append(f"# Possible {mover_descr} unit movements & supports:")
simple_moves = [o for o in orders if _is_simple_move(o)]
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
if not simple_moves and not hold_orders:
block.append(f"{ind}None")
@ -737,22 +741,19 @@ def _generate_rich_order_context_movement(
# ---- Moves (same behaviour as before) ----
for mv in simple_moves:
mover, dest = _split_move(mv)
occ = _dest_occupancy_desc(
dest.split("/")[0][:3], game_map, board_state, power_name
)
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
block.append(f"{ind}{mv} {occ}")
for s in _all_support_examples(mover, dest, possible_orders_for_power):
block.append(f"{ind*2}Available Support: {s}")
block.append(f"{ind * 2}Available Support: {s}")
# ---- Holds (new) ----
for hd in hold_orders:
holder = hd.split(" H")[0] # e.g., 'F DEN'
holder = hd.split(" H")[0] # e.g., 'F DEN'
block.append(f"{ind}{hd}")
for s in _all_support_hold_examples(holder, possible_orders_for_power):
block.append(f"{ind*2}Available Support: {s}")
block.append(f"{ind * 2}Available Support: {s}")
block.append(f"</Territory {prov_short}>")
blocks.append("\n".join(block))
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
# Retreat-phase builder echo orders verbatim, no tags
# ---------------------------------------------------------------------------
def _generate_rich_order_context_retreat(
game: Any,
power_name: str,
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
return "\n".join(lines) if lines else "(No dislodged units)"
# ---------------------------------------------------------------------------
# Adjustment-phase builder summary line + orders, no WAIVEs, no tags
# ---------------------------------------------------------------------------
def _generate_rich_order_context_adjustment(
game: Any,
power_name: str,
@ -800,12 +802,12 @@ def _generate_rich_order_context_adjustment(
* No wrapper tags.
"""
board_state = game.get_state()
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
units_num = len(board_state.get("units", {}).get(power_name, []))
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
units_num = len(board_state.get("units", {}).get(power_name, []))
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
# ----- summary line ----------------------------------------------------
if delta > 0:
if delta > 0:
summary = f"Builds available: {delta}"
elif delta < 0:
summary = f"Disbands required: {-delta}"
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
return "\n".join(lines) if len(lines) > 1 else summary
# ---------------------------------------------------------------------------
# Phase-dispatch wrapper (public entry point)
# ---------------------------------------------------------------------------
def generate_rich_order_context(
game: Any,
power_name: str,
@ -846,22 +846,14 @@ def generate_rich_order_context(
phase_type = game.current_short_phase[-1]
if phase_type == "M": # Movement
return _generate_rich_order_context_movement(
game, power_name, possible_orders_for_power
)
if phase_type == "M": # Movement
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
if phase_type == "R": # Retreat
return _generate_rich_order_context_retreat(
game, power_name, possible_orders_for_power
)
if phase_type == "R": # Retreat
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
if phase_type == "A": # Adjustment (build / disband)
return _generate_rich_order_context_adjustment(
game, power_name, possible_orders_for_power
)
if phase_type == "A": # Adjustment (build / disband)
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
# Fallback treat unknown formats as movement
return _generate_rich_order_context_movement(
game, power_name, possible_orders_for_power
)
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)

View file

@ -11,7 +11,6 @@ from .possible_order_context import (
generate_rich_order_context,
generate_rich_order_context_xml,
)
import os
from .game_history import GameHistory # Assuming GameHistory is correctly importable
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
return final_prompt

View file

@ -16,7 +16,7 @@ from ..config import config
if TYPE_CHECKING:
from .clients import BaseModelClient
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
# from .agent import DiplomacyAgent
# from .agent import DiplomacyAgent
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
@ -32,12 +32,12 @@ def atomic_write_json(data: dict, filepath: str):
dir_name = os.path.dirname(filepath)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
# Write to a temporary file in the same directory
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
with open(temp_filepath, 'w', encoding='utf-8') as f:
with open(temp_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
# Atomically rename the temporary file to the final destination
os.rename(temp_filepath, filepath)
except Exception as e:
@ -55,15 +55,15 @@ def assign_models_to_powers() -> Dict[str, str]:
Example usage: define which model each power uses.
Return a dict: { power_name: model_id, ... }
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
deepseek-chat, deepseek-reasoner
openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
"""
# POWER MODELS
"""
return {
@ -77,13 +77,13 @@ def assign_models_to_powers() -> Dict[str, str]:
}
"""
# TEST MODELS
return {
"AUSTRIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ENGLAND": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"FRANCE": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"GERMANY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"RUSSIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"TURKEY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
}
@ -92,30 +92,26 @@ def assign_models_to_powers() -> Dict[str, str]:
def get_special_models() -> Dict[str, str]:
"""
Define models for special purposes like phase summaries and formatting.
These can be overridden via environment variables:
- AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
- AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting (default: "google/gemini-2.5-flash-lite-preview-06-17")
Returns:
dict: {
"phase_summary": model for generating narrative phase summaries,
"formatter": model for formatting natural language to JSON
}
Examples:
# Use Claude for phase summaries
export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"
# Use a different Gemini model for formatting
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
"""
return {
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
}
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
"""
@ -132,7 +128,7 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
async def get_valid_orders(
game: Game,
client, # BaseModelClient instance
client, # BaseModelClient instance
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
@ -170,20 +166,24 @@ async def get_valid_orders(
)
invalid_info: list[str] = []
valid: list[str] = []
valid: list[str] = []
invalid: list[str] = []
# ── 2. Type check ──────────────────────────────────────────
if not isinstance(raw_orders, list):
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
power_name, raw_orders)
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
model_error_stats[client.model_name]["order_decoding_errors"] += 1
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
# ── 3. Round-trip validation with engine ───────────────────
CODE_TO_ENGINE = {
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
"AUT": "AUSTRIA",
"ENG": "ENGLAND",
"FRA": "FRANCE",
"GER": "GERMANY",
"ITA": "ITALY",
"RUS": "RUSSIA",
"TUR": "TURKEY",
}
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
@ -202,16 +202,16 @@ async def get_valid_orders(
game.set_orders(engine_power, [upper])
normed = game.get_orders(engine_power)
if normed: # accepted
if normed: # accepted
valid.append(normed[0])
else: # rejected
else: # rejected
invalid.append(upper)
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
game.clear_orders(engine_power) # clean slate for main engine flow
# ── 4. Legacy logging & stats updates ──────────────────────
if invalid_info: # at least one bad move
if invalid_info: # at least one bad move
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
model_error_stats[client.model_name]["order_decoding_errors"] += 1
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
@ -226,8 +226,6 @@ async def get_valid_orders(
return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders(
issued_orders: Dict[str, List[str]],
accepted_orders_dict: Dict[str, List[str]],
@ -325,18 +323,18 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
3. Elif *prompts_dir* is provided join prompts_dir + filename.
4. Otherwise join the packages default prompts dir.
"""
if os.path.isabs(filename): # rule 1
if os.path.isabs(filename): # rule 1
prompt_path = filename
elif os.path.dirname(filename): # rule 2 (has slash)
elif os.path.dirname(filename): # rule 2 (has slash)
# If it's a relative path with directory, join with prompts_dir if provided
if prompts_dir:
prompt_path = os.path.join(prompts_dir, filename)
else:
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
elif prompts_dir: # rule 3
elif prompts_dir: # rule 3
prompt_path = os.path.join(prompts_dir, filename)
else: # rule 4
else: # rule 4
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
@ -348,16 +346,14 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
return ""
# == New LLM Response Logging Function ==
def log_llm_response(
log_file_path: str,
model_name: str,
power_name: Optional[str], # Optional for non-power-specific calls like summary
power_name: Optional[str], # Optional for non-power-specific calls like summary
phase: str,
response_type: str,
raw_input_prompt: str, # Added new parameter for the raw input
raw_input_prompt: str, # Added new parameter for the raw input
raw_response: str,
success: str, # Changed from bool to str
):
@ -365,8 +361,8 @@ def log_llm_response(
try:
# Ensure the directory exists
log_dir = os.path.dirname(log_file_path)
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
os.makedirs(log_dir, exist_ok=True)
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
os.makedirs(log_dir, exist_ok=True)
# Check if file exists and has content to determine if we need headers
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
@ -374,34 +370,38 @@ def log_llm_response(
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
# Added "raw_input" to fieldnames
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
escapechar='\\') # Use backslash for escaping
writer = csv.DictWriter(
csvfile,
fieldnames=fieldnames,
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
escapechar="\\",
) # Use backslash for escaping
if not file_exists:
writer.writeheader() # Write header only if file is new
writer.writerow({
"model": model_name,
"power": power_name if power_name else "game", # Use 'game' if no specific power
"phase": phase,
"response_type": response_type,
"raw_input": raw_input_prompt, # Added raw_input to the row
"raw_response": raw_response,
"success": success,
})
writer.writerow(
{
"model": model_name,
"power": power_name if power_name else "game", # Use 'game' if no specific power
"phase": phase,
"response_type": response_type,
"raw_input": raw_input_prompt, # Added raw_input to the row
"raw_response": raw_response,
"success": success,
}
)
except Exception as e:
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
async def run_llm_and_log(
client: "BaseModelClient",
prompt: str,
log_file_path: str, # Kept for context, but not used for logging here
log_file_path: str, # Kept for context, but not used for logging here
power_name: Optional[str], # Kept for context, but not used for logging here
phase: str, # Kept for context, but not used for logging here
response_type: str, # Kept for context, but not used for logging here
phase: str, # Kept for context, but not used for logging here
response_type: str, # Kept for context, but not used for logging here
temperature: float = 0.0,
*,
attempts: int = 5,
@ -428,43 +428,34 @@ async def run_llm_and_log(
except Exception as e:
if attempt == attempts - 1:
logger.error(
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
f"in phase {phase}: {e}",
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
exc_info=True,
)
# Back-off before the next attempt (unless this was the last)
delay = backoff_base * (backoff_factor ** attempt) + random.uniform(0, jitter)
delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
await asyncio.sleep(delay)
return raw_response
# This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation
# This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation
# between trials.
# Temp=0 is important for better performance on deciding moves, and to
# Temp=0 is important for better performance on deciding moves, and to
# ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
# Generate x lines of y random alphanumeric characters
seed_lines = [
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
for _ in range(n_lines)
]
random_seed_block = (
"<RANDOM SEED PLEASE IGNORE>\n" +
"\n".join(seed_lines) +
"\n</RANDOM SEED>"
)
return random_seed_block
# Generate x lines of y random alphanumeric characters
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
return random_seed_block
def get_prompt_path(prompt_name: str) -> str:
"""Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting.
Args:
prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt")
Returns:
str: Either "unformatted/{prompt_name}" or just "{prompt_name}"
"""
@ -472,36 +463,37 @@ def get_prompt_path(prompt_name: str) -> str:
return f"unformatted/{prompt_name}"
else:
return prompt_name
def normalize_recipient_name(recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
return recipient
recipient = recipient.upper().strip()
# Handle common LLM typos and abbreviations found in data
name_mapping = {
'EGMANY': 'GERMANY',
'GERMAN': 'GERMANY',
'UK': 'ENGLAND',
'BRIT': 'ENGLAND',
'ENGLAND': 'ENGLAND', # Keep as-is
'FRANCE': 'FRANCE', # Keep as-is
'GERMANY': 'GERMANY', # Keep as-is
'ITALY': 'ITALY', # Keep as-is
'AUSTRIA': 'AUSTRIA', # Keep as-is
'RUSSIA': 'RUSSIA', # Keep as-is
'TURKEY': 'TURKEY', # Keep as-is
'Germany': 'GERMANY',
'England': 'ENGLAND',
'France': 'FRANCE',
'Italy': 'ITALY',
'Russia': 'RUSSIA',
'Austria': 'AUSTRIA',
'Turkey': 'TURKEY',
}
normalized = name_mapping.get(recipient, recipient)
return normalized
def normalize_recipient_name(recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
return recipient
recipient = recipient.upper().strip()
# Handle common LLM typos and abbreviations found in data
name_mapping = {
"EGMANY": "GERMANY",
"GERMAN": "GERMANY",
"UK": "ENGLAND",
"BRIT": "ENGLAND",
"ENGLAND": "ENGLAND", # Keep as-is
"FRANCE": "FRANCE", # Keep as-is
"GERMANY": "GERMANY", # Keep as-is
"ITALY": "ITALY", # Keep as-is
"AUSTRIA": "AUSTRIA", # Keep as-is
"RUSSIA": "RUSSIA", # Keep as-is
"TURKEY": "TURKEY", # Keep as-is
"Germany": "GERMANY",
"England": "ENGLAND",
"France": "FRANCE",
"Italy": "ITALY",
"Russia": "RUSSIA",
"Austria": "AUSTRIA",
"Turkey": "TURKEY",
}
normalized = name_mapping.get(recipient, recipient)
return normalized