mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Ruff format and check
This commit is contained in:
parent
aadedd6512
commit
ad817695ef
13 changed files with 738 additions and 973 deletions
|
|
@ -10,7 +10,7 @@ import ast
|
|||
from ..config import config
|
||||
|
||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||
from .clients import BaseModelClient, load_model_client
|
||||
from .clients import BaseModelClient
|
||||
|
||||
# Import load_prompt and the new logging wrapper from utils
|
||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
||||
|
|
@ -336,7 +336,7 @@ class DiplomacyAgent:
|
|||
)
|
||||
# If even the last resort doesn't give a dict, return empty dict
|
||||
return {}
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
||||
return {}
|
||||
|
||||
|
|
@ -965,7 +965,7 @@ class DiplomacyAgent:
|
|||
other_powers = [p for p in game.powers if p != power_name]
|
||||
|
||||
# Create a readable board state string from the board_state dict
|
||||
board_state_str = f"Board State:\n"
|
||||
board_state_str = "Board State:\n"
|
||||
for p_name, power_data in board_state.get("powers", {}).items():
|
||||
# Get units and centers from the board state
|
||||
units = power_data.get("units", [])
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -11,25 +11,28 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||
import os
|
||||
|
||||
try:
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_diary_consolidation(
|
||||
agent: 'DiplomacyAgent',
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
log_file_path: str,
|
||||
entries_to_keep_unsummarized: int = 6,
|
||||
|
|
@ -39,70 +42,42 @@ async def run_diary_consolidation(
|
|||
Consolidate older diary entries while keeping recent ones.
|
||||
This is the logic moved from the DiplomacyAgent class.
|
||||
"""
|
||||
logger.info(
|
||||
f"[{agent.power_name}] CONSOLIDATION START — "
|
||||
f"{len(agent.full_private_diary)} total full entries"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
||||
|
||||
full_entries = [
|
||||
e for e in agent.full_private_diary
|
||||
if not e.startswith("[CONSOLIDATED HISTORY]")
|
||||
]
|
||||
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
||||
|
||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.info(
|
||||
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
||||
"skipping consolidation"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
||||
return
|
||||
|
||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||
if not match:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
||||
"aborting consolidation"
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
return
|
||||
|
||||
cutoff_year = int(match.group(1))
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
|
||||
|
||||
def _entry_year(entry: str) -> int | None:
|
||||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
entries_to_summarize = [
|
||||
e for e in full_entries
|
||||
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
||||
]
|
||||
entries_to_keep = [
|
||||
e for e in full_entries
|
||||
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
||||
]
|
||||
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
||||
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
||||
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
||||
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
||||
|
||||
if not entries_to_summarize:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.warning(
|
||||
f"[{agent.power_name}] No eligible entries to summarise; "
|
||||
"context diary left unchanged"
|
||||
)
|
||||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||
return
|
||||
|
||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
if not prompt_template:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||
return
|
||||
|
||||
prompt = prompt_template.format(
|
||||
|
|
@ -132,27 +107,18 @@ async def run_diary_consolidation(
|
|||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||
success_flag = "TRUE"
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Consolidation complete — "
|
||||
f"{len(agent.private_diary)} context entries now"
|
||||
)
|
||||
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
||||
)
|
||||
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
||||
finally:
|
||||
log_llm_response(
|
||||
log_file_path=log_file_path,
|
||||
model_name=(
|
||||
consolidation_client.model_name
|
||||
if consolidation_client is not None
|
||||
else agent.client.model_name
|
||||
),
|
||||
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
||||
power_name=agent.power_name,
|
||||
phase=game.current_short_phase,
|
||||
response_type="diary_consolidation",
|
||||
raw_input_prompt=prompt,
|
||||
raw_response=raw_response,
|
||||
success=success_flag,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
|
|||
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
# Import logging function and model configuration
|
||||
from .utils import log_llm_response, get_special_models
|
||||
|
||||
# Import client loading function
|
||||
from .clients import load_model_client
|
||||
|
||||
|
|
@ -26,77 +25,73 @@ FORMAT_ORDER_DIARY = "order_diary"
|
|||
|
||||
|
||||
async def format_with_gemini_flash(
|
||||
raw_response: str,
|
||||
format_type: str,
|
||||
power_name: Optional[str] = None,
|
||||
phase: Optional[str] = None,
|
||||
log_file_path: Optional[str] = None
|
||||
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Format a natural language response into required JSON structure using Gemini Flash.
|
||||
|
||||
|
||||
Args:
|
||||
raw_response: Natural language response from primary LLM
|
||||
format_type: Type of formatting required (e.g., FORMAT_ORDERS, FORMAT_STATE_UPDATE)
|
||||
power_name: Optional power name for logging
|
||||
phase: Optional phase for logging
|
||||
log_file_path: Optional path for CSV logging
|
||||
|
||||
|
||||
Returns:
|
||||
JSON string in the expected format
|
||||
"""
|
||||
# Map format types to prompt files
|
||||
format_prompts = {
|
||||
FORMAT_STATE_UPDATE: "formatting/format_state_update.txt",
|
||||
FORMAT_CONVERSATION: "formatting/format_conversation.txt",
|
||||
FORMAT_CONVERSATION: "formatting/format_conversation.txt",
|
||||
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
||||
FORMAT_ORDERS: "formatting/format_orders.txt",
|
||||
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
|
||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
|
||||
}
|
||||
|
||||
|
||||
if format_type not in format_prompts:
|
||||
raise ValueError(f"Unknown format type: {format_type}")
|
||||
|
||||
|
||||
# Load the formatting prompt
|
||||
prompt_file = Path(__file__).parent / "prompts" / format_prompts[format_type]
|
||||
if not prompt_file.exists():
|
||||
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
||||
|
||||
with open(prompt_file, 'r') as f:
|
||||
|
||||
with open(prompt_file, "r") as f:
|
||||
format_prompt = f.read()
|
||||
|
||||
|
||||
# Replace placeholder with actual response
|
||||
format_prompt = format_prompt.replace("[RAW_RESPONSE]", raw_response)
|
||||
|
||||
|
||||
# Get model name from configuration
|
||||
special_models = get_special_models()
|
||||
model_name = special_models["formatter"]
|
||||
|
||||
|
||||
# Load the formatter client using the same logic as other models
|
||||
formatter_client = load_model_client(model_name)
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"[FORMATTER] Calling {model_name} for {format_type} formatting")
|
||||
|
||||
|
||||
# Create the full prompt with system message
|
||||
system_content = "You are a precise formatting assistant. Extract and format information exactly as requested."
|
||||
formatter_client.set_system_prompt(system_content)
|
||||
|
||||
|
||||
# Use the client's generate_response method
|
||||
formatted_response = await formatter_client.generate_response(
|
||||
prompt=format_prompt,
|
||||
temperature=0, # Deterministic formatting
|
||||
inject_random_seed=False # No need for random seed in formatting
|
||||
inject_random_seed=False, # No need for random seed in formatting
|
||||
)
|
||||
|
||||
|
||||
if not formatted_response:
|
||||
logger.warning(f"[FORMATTER] {model_name} returned empty response")
|
||||
return ""
|
||||
|
||||
|
||||
# Log successful formatting
|
||||
logger.info(f"[FORMATTER] Successfully formatted {format_type} response")
|
||||
|
||||
|
||||
# Strip any markdown formatting that Gemini might add
|
||||
if formatted_response.startswith("```json"):
|
||||
formatted_response = formatted_response[7:]
|
||||
|
|
@ -105,7 +100,7 @@ async def format_with_gemini_flash(
|
|||
if formatted_response.endswith("```"):
|
||||
formatted_response = formatted_response[:-3]
|
||||
formatted_response = formatted_response.strip()
|
||||
|
||||
|
||||
# Log if requested
|
||||
if log_file_path:
|
||||
log_llm_response(
|
||||
|
|
@ -116,14 +111,14 @@ async def format_with_gemini_flash(
|
|||
response_type=f"format_{format_type}",
|
||||
raw_input_prompt=format_prompt,
|
||||
raw_response=formatted_response,
|
||||
success="Success"
|
||||
success="Success",
|
||||
)
|
||||
|
||||
|
||||
return formatted_response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FORMATTER] Error calling Gemini Flash: {e}")
|
||||
|
||||
|
||||
# Log error if requested
|
||||
if log_file_path:
|
||||
log_llm_response(
|
||||
|
|
@ -134,11 +129,11 @@ async def format_with_gemini_flash(
|
|||
response_type=f"format_{format_type}",
|
||||
raw_input_prompt=format_prompt,
|
||||
raw_response=f"ERROR: {str(e)}",
|
||||
success=f"Failure: {type(e).__name__}"
|
||||
success=f"Failure: {type(e).__name__}",
|
||||
)
|
||||
|
||||
|
||||
# Return empty structure based on format type
|
||||
if format_type == FORMAT_CONVERSATION:
|
||||
return "[]"
|
||||
else:
|
||||
return "{}"
|
||||
return "{}"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from pprint import pformat
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -24,12 +23,8 @@ class Phase:
|
|||
plans: Dict[str, str] = field(default_factory=dict)
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
||||
orders_by_power: Dict[str, List[str]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
results_by_power: Dict[str, List[List[str]]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
|
||||
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
|
||||
# NEW: Store phase-end summaries provided by each power
|
||||
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
||||
# NEW: Store experience/journal updates from each power for this phase
|
||||
|
|
@ -37,11 +32,9 @@ class Phase:
|
|||
|
||||
def add_plan(self, power_name: str, plan: str):
|
||||
self.plans[power_name] = plan
|
||||
|
||||
|
||||
def add_message(self, sender: str, recipient: str, content: str):
|
||||
self.messages.append(
|
||||
Message(sender=sender, recipient=recipient, content=content)
|
||||
)
|
||||
self.messages.append(Message(sender=sender, recipient=recipient, content=content))
|
||||
|
||||
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
||||
self.orders_by_power[power].extend(orders)
|
||||
|
|
@ -110,14 +103,10 @@ class GameHistory:
|
|||
phase.plans[power_name] = plan
|
||||
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
||||
|
||||
def add_message(
|
||||
self, phase_name: str, sender: str, recipient: str, message_content: str
|
||||
):
|
||||
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
|
||||
phase = self._get_phase(phase_name)
|
||||
if phase:
|
||||
message = Message(
|
||||
sender=sender, recipient=recipient, content=message_content
|
||||
)
|
||||
message = Message(sender=sender, recipient=recipient, content=message_content)
|
||||
phase.messages.append(message)
|
||||
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
||||
|
||||
|
|
@ -147,16 +136,12 @@ class GameHistory:
|
|||
phase.experience_updates[power_name] = update
|
||||
logger.debug(f"Added experience update for {power_name} in {phase_name}")
|
||||
|
||||
def get_strategic_directives(self):
|
||||
def get_strategic_directives(self):
|
||||
# returns for last phase only if exists
|
||||
if not self.phases:
|
||||
if not self.phases:
|
||||
return {}
|
||||
return self.phases[-1].plans
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_order_history_for_prompt(
|
||||
self,
|
||||
game: "Game",
|
||||
|
|
@ -194,10 +179,7 @@ class GameHistory:
|
|||
return str(tag).strip().lower()
|
||||
|
||||
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
||||
eng2code = {
|
||||
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
|
||||
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
|
||||
}
|
||||
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
|
||||
norm = game.map.norm
|
||||
|
||||
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
||||
|
|
@ -209,14 +191,14 @@ class GameHistory:
|
|||
|
||||
for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
|
||||
submitted = ph.submitted_orders_by_power.get(pwr, [])
|
||||
accepted = ph.orders_by_power.get(pwr, [])
|
||||
accepted = ph.orders_by_power.get(pwr, [])
|
||||
|
||||
if isinstance(submitted, str):
|
||||
submitted = [submitted]
|
||||
if isinstance(accepted, str):
|
||||
accepted = [accepted]
|
||||
|
||||
def _norm_keep(o): # keep WAIVE readable
|
||||
def _norm_keep(o): # keep WAIVE readable
|
||||
return o if o.upper() == "WAIVE" else norm(o)
|
||||
|
||||
sub_norm = {_norm_keep(o): o for o in submitted}
|
||||
|
|
@ -264,13 +246,6 @@ class GameHistory:
|
|||
return "\n(No orders were issued in recent history)\n"
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||
current_phase: Optional[Phase] = None
|
||||
for phase_obj in self.phases:
|
||||
|
|
@ -281,7 +256,7 @@ class GameHistory:
|
|||
if not current_phase:
|
||||
return f"\n(No messages found for current phase: {current_phase_name})\n"
|
||||
|
||||
messages_str = ""
|
||||
messages_str = ""
|
||||
|
||||
global_msgs_content = current_phase.get_global_messages()
|
||||
if global_msgs_content:
|
||||
|
|
@ -299,7 +274,7 @@ class GameHistory:
|
|||
messages_str += "\n"
|
||||
else:
|
||||
messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n (No private messages this round)\n"
|
||||
|
||||
|
||||
if not global_msgs_content and not private_msgs_dict:
|
||||
return f"\n(No messages recorded for current phase: {current_phase_name})\n"
|
||||
|
||||
|
|
@ -313,10 +288,10 @@ class GameHistory:
|
|||
"""
|
||||
if not self.phases:
|
||||
return []
|
||||
|
||||
|
||||
# Get the most recent 2 phases including current phase
|
||||
recent_phases = self.phases[-2:] if len(self.phases) >= 2 else self.phases[-1:]
|
||||
|
||||
|
||||
# Collect all messages sent TO this power
|
||||
messages_to_power = []
|
||||
for phase in recent_phases:
|
||||
|
|
@ -325,48 +300,44 @@ class GameHistory:
|
|||
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
||||
# Skip if sender is this power (don't need to respond to own messages)
|
||||
if msg.sender != power_name:
|
||||
messages_to_power.append({
|
||||
'sender': msg.sender,
|
||||
'content': msg.content,
|
||||
'phase': phase.name
|
||||
})
|
||||
|
||||
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
|
||||
|
||||
# Add debug logging
|
||||
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
||||
if not messages_to_power:
|
||||
logger.info(f"No messages found for {power_name} to respond to")
|
||||
|
||||
|
||||
# Take the most recent 'limit' messages
|
||||
return messages_to_power[-limit:] if messages_to_power else []
|
||||
|
||||
|
||||
def get_ignored_messages_by_power(self, sender_name: str, num_phases: int = 3) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""
|
||||
Identifies which powers are not responding to messages from sender_name.
|
||||
Returns a dict mapping power names to their ignored messages.
|
||||
|
||||
|
||||
A message is considered ignored if:
|
||||
1. It was sent from sender_name to another power (private)
|
||||
2. No response from that power was received in the same or next phase
|
||||
"""
|
||||
ignored_by_power = {}
|
||||
|
||||
|
||||
# Get recent phases
|
||||
recent_phases = self.phases[-num_phases:] if self.phases else []
|
||||
if not recent_phases:
|
||||
return ignored_by_power
|
||||
|
||||
|
||||
for i, phase in enumerate(recent_phases):
|
||||
# Get messages sent by sender to specific powers (not global)
|
||||
sender_messages = []
|
||||
for msg in phase.messages:
|
||||
# Handle both Message objects and dict objects
|
||||
if isinstance(msg, Message):
|
||||
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
|
||||
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
|
||||
sender_messages.append(msg)
|
||||
else: # Assume dict
|
||||
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
|
||||
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
|
||||
sender_messages.append(msg)
|
||||
|
||||
|
||||
# Check for responses in this and next phases
|
||||
for msg in sender_messages:
|
||||
# Handle both Message objects and dict objects
|
||||
|
|
@ -374,37 +345,36 @@ class GameHistory:
|
|||
recipient = msg.recipient
|
||||
msg_content = msg.content
|
||||
else:
|
||||
recipient = msg['recipient']
|
||||
msg_content = msg['content']
|
||||
|
||||
recipient = msg["recipient"]
|
||||
msg_content = msg["content"]
|
||||
|
||||
# Look for responses in current phase and next phases
|
||||
found_response = False
|
||||
|
||||
|
||||
# Check remaining phases starting from current
|
||||
for check_phase in recent_phases[i:min(i+2, len(recent_phases))]:
|
||||
for check_phase in recent_phases[i : min(i + 2, len(recent_phases))]:
|
||||
# Look for messages FROM the recipient TO the sender (direct response)
|
||||
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
|
||||
response_msgs = []
|
||||
for m in check_phase.messages:
|
||||
if isinstance(m, Message):
|
||||
if m.sender == recipient and (m.recipient == sender_name or
|
||||
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
|
||||
if m.sender == recipient and (
|
||||
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
|
||||
):
|
||||
response_msgs.append(m)
|
||||
else: # Assume dict
|
||||
if m['sender'] == recipient and (m['recipient'] == sender_name or
|
||||
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
|
||||
if m["sender"] == recipient and (
|
||||
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
|
||||
):
|
||||
response_msgs.append(m)
|
||||
|
||||
|
||||
if response_msgs:
|
||||
found_response = True
|
||||
break
|
||||
|
||||
|
||||
if not found_response:
|
||||
if recipient not in ignored_by_power:
|
||||
ignored_by_power[recipient] = []
|
||||
ignored_by_power[recipient].append({
|
||||
'phase': phase.name,
|
||||
'content': msg_content
|
||||
})
|
||||
|
||||
return ignored_by_power
|
||||
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
|
||||
|
||||
return ignored_by_power
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
from argparse import Namespace
|
||||
|
||||
from diplomacy import Game
|
||||
|
|
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# --- Serialization / Deserialization ---
|
||||
|
||||
|
||||
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||
"""Converts an agent object to a JSON-serializable dictionary."""
|
||||
return {
|
||||
|
|
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
|||
"private_diary": agent.private_diary,
|
||||
}
|
||||
|
||||
def deserialize_agent(
|
||||
agent_data: dict,
|
||||
prompts_dir: Optional[str] = None,
|
||||
*,
|
||||
override_model_id: Optional[str] = None
|
||||
) -> DiplomacyAgent:
|
||||
|
||||
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
||||
"""
|
||||
Recreates an agent object from a dictionary.
|
||||
|
||||
|
|
@ -68,34 +65,31 @@ def deserialize_agent(
|
|||
# --- State Management ---
|
||||
|
||||
# game_logic.py
|
||||
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
||||
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
||||
|
||||
|
||||
def _next_phase_name(short: str) -> str:
|
||||
"""
|
||||
Return the Diplomacy phase string that chronologically follows *short*.
|
||||
(E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M)
|
||||
"""
|
||||
season = short[0] # 'S' | 'W'
|
||||
year = int(short[1:5])
|
||||
typ = short[-1] # 'M' | 'R' | 'A'
|
||||
season = short[0] # 'S' | 'W'
|
||||
year = int(short[1:5])
|
||||
typ = short[-1] # 'M' | 'R' | 'A'
|
||||
|
||||
idx = _PHASE_ORDER.index(typ)
|
||||
if idx < 2: # still in the same season
|
||||
return f"{season}{year}{_PHASE_ORDER[idx+1]}"
|
||||
if idx < 2: # still in the same season
|
||||
return f"{season}{year}{_PHASE_ORDER[idx + 1]}"
|
||||
|
||||
# typ was 'A' → roll season
|
||||
if season == "S": # summer → winter, same year
|
||||
if season == "S": # summer → winter, same year
|
||||
return f"W{year}M"
|
||||
else: # winter→ spring, next year
|
||||
return f"S{year+1}M"
|
||||
else: # winter→ spring, next year
|
||||
return f"S{year + 1}M"
|
||||
|
||||
|
||||
def save_game_state(
|
||||
game: Game,
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: GameHistory,
|
||||
output_path: str,
|
||||
run_config: Namespace,
|
||||
completed_phase_name: str
|
||||
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
||||
):
|
||||
"""
|
||||
Serialise the entire game to JSON, preserving per-phase custom metadata
|
||||
|
|
@ -130,9 +124,7 @@ def save_game_state(
|
|||
if extras:
|
||||
previous_phase_extras[phase["name"]] = extras
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Could not load previous save to retain metadata: %s", exc, exc_info=True
|
||||
)
|
||||
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 2. Build the fresh base structure from the diplomacy library. #
|
||||
|
|
@ -143,16 +135,12 @@ def save_game_state(
|
|||
# 3. Walk every phase and merge the metadata back in. #
|
||||
# -------------------------------------------------------------- #
|
||||
# Capture the *current* snapshot of every live agent exactly once.
|
||||
current_state_agents = {
|
||||
p_name: serialize_agent(p_agent)
|
||||
for p_name, p_agent in agents.items()
|
||||
if not game.powers[p_name].is_eliminated()
|
||||
}
|
||||
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
||||
|
||||
for phase_block in saved_game.get("phases", []):
|
||||
if int(phase_block["name"][1:5]) > run_config.max_year:
|
||||
break
|
||||
|
||||
|
||||
phase_name = phase_block["name"]
|
||||
|
||||
# 3a. Re-attach anything we cached from a previous save.
|
||||
|
|
@ -167,31 +155,24 @@ def save_game_state(
|
|||
|
||||
# Plans for this phase – may be empty in non-movement phases.
|
||||
phase_obj = game_history._get_phase(phase_name)
|
||||
phase_block["state_history_plans"] = (
|
||||
phase_obj.plans if phase_obj else {}
|
||||
)
|
||||
|
||||
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
||||
|
||||
# -------------------------------------------------------------- #
|
||||
# 4. Attach top-level metadata and write atomically. #
|
||||
# -------------------------------------------------------------- #
|
||||
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
||||
saved_game["final_agent_states"] = {
|
||||
p_name: {"relationships": a.relationships, "goals": a.goals}
|
||||
for p_name, a in agents.items()
|
||||
}
|
||||
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
||||
|
||||
# Filter out phases > max_year
|
||||
#saved_game["phases"] = [
|
||||
# saved_game["phases"] = [
|
||||
# ph for ph in saved_game["phases"]
|
||||
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
|
||||
#]
|
||||
# ]
|
||||
atomic_write_json(saved_game, output_path)
|
||||
|
||||
logger.info("Game state saved successfully.")
|
||||
|
||||
|
||||
|
||||
def load_game_state(
|
||||
run_dir: str,
|
||||
game_file_name: str,
|
||||
|
|
@ -204,7 +185,7 @@ def load_game_state(
|
|||
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
||||
|
||||
logger.info(f"Loading game state from: {game_file_path}")
|
||||
with open(game_file_path, 'r') as f:
|
||||
with open(game_file_path, "r") as f:
|
||||
saved_game_data = json.load(f)
|
||||
|
||||
# If resuming, find the specified phase and truncate the data after it
|
||||
|
|
@ -213,49 +194,49 @@ def load_game_state(
|
|||
try:
|
||||
# Find the index of the phase *before* the one we want to resume from.
|
||||
# We will start the simulation *at* the resume_from_phase.
|
||||
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
|
||||
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
||||
# Truncate the list to exclude everything after the resume phase
|
||||
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
||||
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
|
||||
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
||||
|
||||
# Wipe any data that must be regenerated.
|
||||
for key in ("orders", "results", "messages"):
|
||||
saved_game_data['phases'][-1].pop(key, None)
|
||||
saved_game_data["phases"][-1].pop(key, None)
|
||||
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
||||
except StopIteration:
|
||||
# If the phase is not found, maybe it's the first phase (S1901M)
|
||||
if resume_from_phase == "S1901M":
|
||||
saved_game_data['phases'] = []
|
||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||
saved_game_data["phases"] = []
|
||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||
else:
|
||||
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
||||
|
||||
# Reconstruct the Game object
|
||||
last_phase = saved_game_data['phases'][-1]
|
||||
last_phase = saved_game_data["phases"][-1]
|
||||
|
||||
# Wipe the data that must be regenerated **but preserve the keys**
|
||||
last_phase['orders'] = {} # was dict
|
||||
last_phase['results'] = {} # was dict
|
||||
last_phase['messages'] = []
|
||||
last_phase["orders"] = {} # was dict
|
||||
last_phase["results"] = {} # was dict
|
||||
last_phase["messages"] = []
|
||||
|
||||
game = from_saved_game_format(saved_game_data)
|
||||
|
||||
game.phase_summaries = saved_game_data.get('phase_summaries', {})
|
||||
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
||||
|
||||
# Reconstruct agents and game history from the *last* valid phase in the data
|
||||
if not saved_game_data['phases']:
|
||||
if not saved_game_data["phases"]:
|
||||
# This happens if we are resuming from the very beginning (S1901M)
|
||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||
agents = {} # Will be created by the main loop
|
||||
agents = {} # Will be created by the main loop
|
||||
game_history = GameHistory()
|
||||
else:
|
||||
# We save the game state up to & including the current (uncompleted) phase.
|
||||
# So we need to grab the agent state from the previous (completed) phase.
|
||||
if len(saved_game_data['phases']) <= 1:
|
||||
if len(saved_game_data["phases"]) <= 1:
|
||||
last_phase_data = {}
|
||||
else:
|
||||
last_phase_data = saved_game_data['phases'][-2]
|
||||
|
||||
last_phase_data = saved_game_data["phases"][-2]
|
||||
|
||||
# -------------------- Rebuild agents -------------------- #
|
||||
agents = {}
|
||||
|
||||
|
|
@ -269,18 +250,11 @@ def load_game_state(
|
|||
elif len(provided) == 1:
|
||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid --models argument: expected 1 or {len(powers_order)} "
|
||||
f"items, got {len(provided)}."
|
||||
)
|
||||
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||
|
||||
if "state_agents" in last_phase_data:
|
||||
logger.info("Rebuilding agents from saved state...")
|
||||
prompts_dir_from_config = (
|
||||
run_config.prompts_dir
|
||||
if run_config and hasattr(run_config, "prompts_dir")
|
||||
else None
|
||||
)
|
||||
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||
override_id = power_model_map.get(power_name)
|
||||
agents[power_name] = deserialize_agent(
|
||||
|
|
@ -290,40 +264,30 @@ def load_game_state(
|
|||
)
|
||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot resume: 'state_agents' key not found in the last phase "
|
||||
"of the save file."
|
||||
)
|
||||
|
||||
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||
|
||||
# Rebuild GameHistory
|
||||
game_history = GameHistory()
|
||||
logger.info("Rebuilding game history...")
|
||||
for phase_data in saved_game_data['phases'][:-1]:
|
||||
phase_name = phase_data['name']
|
||||
for phase_data in saved_game_data["phases"][:-1]:
|
||||
phase_name = phase_data["name"]
|
||||
game_history.add_phase(phase_name)
|
||||
# Add messages
|
||||
for msg in phase_data.get('messages', []):
|
||||
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
|
||||
for msg in phase_data.get("messages", []):
|
||||
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||
# Add plans
|
||||
if 'state_history_plans' in phase_data:
|
||||
for p_name, plan in phase_data['state_history_plans'].items():
|
||||
if "state_history_plans" in phase_data:
|
||||
for p_name, plan in phase_data["state_history_plans"].items():
|
||||
game_history.add_plan(phase_name, p_name, plan)
|
||||
logger.info("Game history rebuilt.")
|
||||
|
||||
|
||||
return game, agents, game_history, run_config
|
||||
|
||||
|
||||
async def initialize_new_game(
|
||||
args: Namespace,
|
||||
game: Game,
|
||||
game_history: GameHistory,
|
||||
llm_log_file_path: str
|
||||
) -> Dict[str, DiplomacyAgent]:
|
||||
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
||||
"""Initializes agents for a new game."""
|
||||
powers_order = sorted(list(ALL_POWERS))
|
||||
|
||||
|
||||
# Parse token limits
|
||||
default_max_tokens = args.max_tokens
|
||||
model_max_tokens = {p: default_max_tokens for p in powers_order}
|
||||
|
|
@ -339,7 +303,7 @@ async def initialize_new_game(
|
|||
# Handle power model mapping
|
||||
|
||||
if args.models:
|
||||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
if len(provided_models) == len(powers_order):
|
||||
game.power_model_map = dict(zip(powers_order, provided_models))
|
||||
elif len(provided_models) == 1:
|
||||
|
|
@ -364,17 +328,17 @@ async def initialize_new_game(
|
|||
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
|
||||
|
||||
|
||||
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
||||
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
initialized_powers = list(agents.keys())
|
||||
for i, result in enumerate(initialization_results):
|
||||
if i < len(initialized_powers):
|
||||
power_name = initialized_powers[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
||||
else:
|
||||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||
|
||||
return agents
|
||||
if i < len(initialized_powers):
|
||||
power_name = initialized_powers[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
||||
else:
|
||||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||
|
||||
return agents
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
# ai_diplomacy/initialization.py
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from ..config import config
|
||||
|
||||
# Forward declaration for type hinting, actual imports in function if complex
|
||||
if False: # TYPE_CHECKING
|
||||
if False: # TYPE_CHECKING
|
||||
from diplomacy import Game
|
||||
from diplomacy.models.game import GameHistory
|
||||
from .agent import DiplomacyAgent
|
||||
|
|
@ -18,40 +17,40 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_agent_state_ext(
|
||||
agent: 'DiplomacyAgent',
|
||||
game: 'Game',
|
||||
game_history: 'GameHistory',
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
game_history: "GameHistory",
|
||||
log_file_path: str,
|
||||
prompts_dir: Optional[str] = None,
|
||||
):
|
||||
"""Uses the LLM to set initial goals and relationships for the agent."""
|
||||
power_name = agent.power_name
|
||||
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)..." )
|
||||
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)...")
|
||||
current_phase = game.get_current_phase() if game else "UnknownPhase"
|
||||
|
||||
full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block
|
||||
response = "" # Ensure response is defined for finally block
|
||||
success_status = "Failure: Initialized" # Default status
|
||||
response = "" # Ensure response is defined for finally block
|
||||
success_status = "Failure: Initialized" # Default status
|
||||
|
||||
try:
|
||||
# Load the prompt template
|
||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
|
||||
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||
if not initial_prompt_template:
|
||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
||||
return
|
||||
|
||||
|
||||
# Format the prompt with variables
|
||||
initial_prompt = initial_prompt_template.format(
|
||||
power_name=power_name,
|
||||
allowed_labels_str=allowed_labels_str
|
||||
)
|
||||
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||
|
||||
board_state = game.get_state() if game else {}
|
||||
possible_orders = game.get_all_possible_orders() if game else {}
|
||||
|
||||
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
|
||||
logger.debug(
|
||||
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
|
||||
)
|
||||
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
||||
# For initialization, game should always be present.
|
||||
|
||||
|
|
@ -59,12 +58,12 @@ async def initialize_agent_state_ext(
|
|||
|
||||
context = build_context_prompt(
|
||||
game=game,
|
||||
board_state=board_state,
|
||||
board_state=board_state,
|
||||
power_name=power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=game_history,
|
||||
agent_goals=None,
|
||||
agent_relationships=None,
|
||||
possible_orders=possible_orders,
|
||||
game_history=game_history,
|
||||
agent_goals=None,
|
||||
agent_relationships=None,
|
||||
agent_private_diary=formatted_diary,
|
||||
prompts_dir=prompts_dir,
|
||||
)
|
||||
|
|
@ -76,9 +75,9 @@ async def initialize_agent_state_ext(
|
|||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=current_phase,
|
||||
response_type='initialization', # Context for run_llm_and_log internal error logging
|
||||
response_type="initialization", # Context for run_llm_and_log internal error logging
|
||||
)
|
||||
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
||||
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
||||
|
||||
parsed_successfully = False
|
||||
try:
|
||||
|
|
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
|
|||
if config.USE_UNFORMATTED_PROMPTS:
|
||||
# Format the natural language response into JSON
|
||||
formatted_response = await format_with_gemini_flash(
|
||||
response,
|
||||
FORMAT_INITIAL_STATE,
|
||||
power_name=power_name,
|
||||
phase=current_phase,
|
||||
log_file_path=log_file_path
|
||||
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
|
||||
)
|
||||
else:
|
||||
# Use the raw response directly (already formatted)
|
||||
|
|
@ -101,20 +96,24 @@ async def initialize_agent_state_ext(
|
|||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...")
|
||||
success_status = "Failure: JSONDecodeError"
|
||||
update_data = {} # Ensure update_data exists for fallback logic below
|
||||
parsed_successfully = False # Explicitly set here too
|
||||
update_data = {} # Ensure update_data exists for fallback logic below
|
||||
parsed_successfully = False # Explicitly set here too
|
||||
# Fallback logic for goals/relationships will be handled later if update_data is empty
|
||||
|
||||
# Defensive check for update_data type if parsing was initially considered successful
|
||||
if parsed_successfully:
|
||||
if parsed_successfully:
|
||||
if isinstance(update_data, str):
|
||||
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
|
||||
update_data = {} # Treat as parsing failure
|
||||
logger.error(
|
||||
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
|
||||
)
|
||||
update_data = {} # Treat as parsing failure
|
||||
parsed_successfully = False
|
||||
success_status = "Failure: ParsedAsStr"
|
||||
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
||||
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
|
||||
update_data = {} # Treat as parsing failure
|
||||
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
||||
logger.error(
|
||||
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
|
||||
)
|
||||
update_data = {} # Treat as parsing failure
|
||||
parsed_successfully = False
|
||||
success_status = "Failure: NotADict"
|
||||
|
||||
|
|
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
|
|||
initial_relationships_applied = False
|
||||
|
||||
if parsed_successfully:
|
||||
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
|
||||
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
|
||||
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
|
||||
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
|
||||
|
||||
if isinstance(initial_goals, list) and initial_goals:
|
||||
agent.goals = initial_goals
|
||||
|
|
@ -152,30 +151,30 @@ async def initialize_agent_state_ext(
|
|||
else:
|
||||
logger.warning(f"[{power_name}] No valid relationships found in LLM response.")
|
||||
else:
|
||||
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
|
||||
|
||||
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
|
||||
|
||||
if initial_goals_applied or initial_relationships_applied:
|
||||
success_status = "Success: Applied LLM data"
|
||||
elif parsed_successfully: # Parsed but nothing useful to apply
|
||||
elif parsed_successfully: # Parsed but nothing useful to apply
|
||||
success_status = "Success: Parsed but no data applied"
|
||||
# If not parsed_successfully, success_status is already "Failure: JSONDecodeError"
|
||||
|
||||
# Fallback if LLM data was not applied or parsing failed
|
||||
if not initial_goals_applied:
|
||||
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
|
||||
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
|
||||
agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"]
|
||||
agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.")
|
||||
logger.info(f"[{power_name}] Default goals set.")
|
||||
|
||||
|
||||
if not initial_relationships_applied:
|
||||
# Check if relationships are still default-like before overriding
|
||||
# Check if relationships are still default-like before overriding
|
||||
is_default_relationships = True
|
||||
if agent.relationships: # Check if it's not empty
|
||||
if agent.relationships: # Check if it's not empty
|
||||
for p in ALL_POWERS:
|
||||
if p != power_name and agent.relationships.get(p) != "Neutral":
|
||||
is_default_relationships = False
|
||||
break
|
||||
if is_default_relationships:
|
||||
if is_default_relationships:
|
||||
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
|
||||
agent.add_journal_entry(f"[{current_phase}] Set default neutral relationships as LLM provided none valid or parse failed.")
|
||||
logger.info(f"[{power_name}] Default neutral relationships set.")
|
||||
|
|
@ -191,16 +190,16 @@ async def initialize_agent_state_ext(
|
|||
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
|
||||
logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}")
|
||||
finally:
|
||||
if log_file_path: # Ensure log_file_path is provided
|
||||
if log_file_path: # Ensure log_file_path is provided
|
||||
log_llm_response(
|
||||
log_file_path=log_file_path,
|
||||
model_name=agent.client.model_name if agent and agent.client else "UnknownModel",
|
||||
power_name=power_name,
|
||||
phase=current_phase,
|
||||
response_type="initial_state_setup", # Specific type for CSV logging
|
||||
response_type="initial_state_setup", # Specific type for CSV logging
|
||||
raw_input_prompt=full_prompt,
|
||||
raw_response=response,
|
||||
success=success_status
|
||||
success=success_status,
|
||||
)
|
||||
|
||||
# Final log of state after initialization attempt
|
||||
|
|
|
|||
|
|
@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
|
|||
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
||||
|
||||
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
|
|||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
from .agent import DiplomacyAgent
|
||||
from .clients import load_model_client
|
||||
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
|
||||
from .utils import gather_possible_orders, normalize_recipient_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .game_history import GameHistory
|
||||
|
|
@ -21,9 +20,9 @@ load_dotenv()
|
|||
|
||||
|
||||
async def conduct_negotiations(
|
||||
game: 'Game',
|
||||
game: "Game",
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: 'GameHistory',
|
||||
game_history: "GameHistory",
|
||||
model_error_stats: Dict[str, Dict[str, int]],
|
||||
log_file_path: str,
|
||||
max_rounds: int = 3,
|
||||
|
|
@ -35,13 +34,9 @@ async def conduct_negotiations(
|
|||
"""
|
||||
logger.info("Starting negotiation phase.")
|
||||
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
eliminated_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||
|
||||
logger.info(f"Active powers for negotiations: {active_powers}")
|
||||
if eliminated_powers:
|
||||
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
|
||||
|
|
@ -51,7 +46,7 @@ async def conduct_negotiations(
|
|||
# We do up to 'max_rounds' single-message turns for each power
|
||||
for round_index in range(max_rounds):
|
||||
logger.info(f"Negotiation Round {round_index + 1}/{max_rounds}")
|
||||
|
||||
|
||||
# Prepare tasks for asyncio.gather
|
||||
tasks = []
|
||||
power_names_for_tasks = []
|
||||
|
|
@ -99,28 +94,28 @@ async def conduct_negotiations(
|
|||
# Process results
|
||||
for i, result in enumerate(results):
|
||||
power_name = power_names_for_tasks[i]
|
||||
agent = agents[power_name] # Get agent again for journaling
|
||||
model_name = agent.client.model_name # Get model name for stats
|
||||
agent = agents[power_name] # Get agent again for journaling
|
||||
model_name = agent.client.model_name # Get model name for stats
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
|
||||
# Use model_name for stats key if possible
|
||||
if model_name in model_error_stats:
|
||||
model_error_stats[model_name]["conversation_errors"] += 1
|
||||
else: # Fallback to power_name if model name not tracked (shouldn't happen)
|
||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
messages = [] # Treat as no messages on error
|
||||
elif result is None: # Handle case where client might return None on internal error
|
||||
logger.warning(f"Received None instead of messages for {power_name}.")
|
||||
messages = []
|
||||
if model_name in model_error_stats:
|
||||
model_error_stats[model_name]["conversation_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
model_error_stats[model_name]["conversation_errors"] += 1
|
||||
else: # Fallback to power_name if model name not tracked (shouldn't happen)
|
||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
messages = [] # Treat as no messages on error
|
||||
elif result is None: # Handle case where client might return None on internal error
|
||||
logger.warning(f"Received None instead of messages for {power_name}.")
|
||||
messages = []
|
||||
if model_name in model_error_stats:
|
||||
model_error_stats[model_name]["conversation_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
else:
|
||||
messages = result # result is the list of message dicts
|
||||
messages = result # result is the list of message dicts
|
||||
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
|
||||
|
||||
# Process the received messages (same logic as before)
|
||||
|
|
@ -134,27 +129,27 @@ async def conduct_negotiations(
|
|||
# Create an official message in the Diplomacy engine
|
||||
# Determine recipient based on message type
|
||||
if message.get("message_type") == "private":
|
||||
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
|
||||
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
|
||||
if recipient not in game.powers and recipient != GLOBAL:
|
||||
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
|
||||
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
|
||||
else: # Assume global if not private or type is missing
|
||||
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
|
||||
else: # Assume global if not private or type is missing
|
||||
recipient = GLOBAL
|
||||
|
||||
|
||||
diplo_message = Message(
|
||||
phase=game.current_short_phase,
|
||||
sender=power_name,
|
||||
recipient=recipient, # Use determined recipient
|
||||
message=message.get("content", ""), # Use .get for safety
|
||||
time_sent=None, # Let the engine assign time
|
||||
recipient=recipient, # Use determined recipient
|
||||
message=message.get("content", ""), # Use .get for safety
|
||||
time_sent=None, # Let the engine assign time
|
||||
)
|
||||
game.add_message(diplo_message)
|
||||
# Also add to our custom history
|
||||
game_history.add_message(
|
||||
game.current_short_phase,
|
||||
power_name,
|
||||
recipient, # Use determined recipient here too
|
||||
message.get("content", ""), # Use .get for safety
|
||||
recipient, # Use determined recipient here too
|
||||
message.get("content", ""), # Use .get for safety
|
||||
)
|
||||
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
||||
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
from dotenv import load_dotenv
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from typing import Dict
|
||||
|
||||
from .clients import load_model_client
|
||||
from .game_history import GameHistory
|
||||
from .agent import DiplomacyAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def planning_phase(
|
||||
game,
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: GameHistory,
|
||||
game,
|
||||
agents: Dict[str, DiplomacyAgent],
|
||||
game_history: GameHistory,
|
||||
model_error_stats,
|
||||
log_file_path: str,
|
||||
):
|
||||
|
|
@ -20,24 +19,18 @@ async def planning_phase(
|
|||
Lets each power generate a strategic plan using their DiplomacyAgent.
|
||||
"""
|
||||
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
eliminated_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||
|
||||
logger.info(f"Active powers for planning: {active_powers}")
|
||||
if eliminated_powers:
|
||||
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
|
||||
else:
|
||||
logger.info("No eliminated powers yet.")
|
||||
|
||||
|
||||
board_state = game.get_state()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
|
||||
futures = {}
|
||||
for power_name in active_powers:
|
||||
if power_name not in agents:
|
||||
|
|
@ -45,7 +38,7 @@ async def planning_phase(
|
|||
continue
|
||||
agent = agents[power_name]
|
||||
client = agent.client
|
||||
|
||||
|
||||
future = executor.submit(
|
||||
client.get_plan,
|
||||
game,
|
||||
|
|
@ -67,20 +60,18 @@ async def planning_phase(
|
|||
try:
|
||||
plan_result = future.result()
|
||||
logger.info(f"Received planning result from {power_name}.")
|
||||
|
||||
|
||||
if plan_result.startswith("Error:"):
|
||||
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
||||
if power_name in model_error_stats:
|
||||
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
|
||||
model_error_stats[power_name]['planning_generation_errors'] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
|
||||
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
|
||||
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
||||
if power_name in model_error_stats:
|
||||
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
|
||||
model_error_stats[power_name]["planning_generation_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
|
||||
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
|
||||
elif plan_result:
|
||||
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
||||
game_history.add_plan(
|
||||
game.current_short_phase, power_name, plan_result
|
||||
)
|
||||
game_history.add_plan(game.current_short_phase, power_name, plan_result)
|
||||
logger.debug(f"Added plan for {power_name} to history.")
|
||||
else:
|
||||
logger.warning(f"Agent {power_name} returned an empty plan.")
|
||||
|
|
@ -88,11 +79,11 @@ async def planning_phase(
|
|||
except Exception as e:
|
||||
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
||||
if power_name in model_error_stats:
|
||||
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
|
||||
model_error_stats[power_name]['planning_execution_errors'] += 1
|
||||
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
|
||||
model_error_stats[power_name]["planning_execution_errors"] += 1
|
||||
else:
|
||||
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
|
||||
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
|
||||
|
||||
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
|
||||
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
|
||||
|
||||
logger.info("Planning phase processing complete.")
|
||||
return game_history
|
||||
return game_history
|
||||
|
|
|
|||
|
|
@ -8,12 +8,13 @@ import logging
|
|||
import re
|
||||
|
||||
# Placeholder for actual map type from diplomacy.engine.map.Map
|
||||
# GameMap = Any
|
||||
# GameMap = Any
|
||||
# Type hint for board_state dictionary from game.get_state()
|
||||
# BoardState = Dict[str, Any]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
||||
|
|
@ -22,20 +23,18 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
|||
"""
|
||||
# ── collect all 3-letter province codes ───────────────────────────────
|
||||
provs: Set[str] = {
|
||||
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
|
||||
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
|
||||
for loc in game_map.locs
|
||||
if len(loc.split("/")[0]) == 3
|
||||
}
|
||||
|
||||
graph: Dict[str, Dict[str, List[str]]] = {
|
||||
p: {"ARMY": [], "FLEET": []} for p in provs
|
||||
}
|
||||
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
|
||||
|
||||
# ── helper: list every concrete variant of a province ─────────────────
|
||||
def variants(code: str) -> List[str]:
|
||||
lst = list(game_map.loc_coasts.get(code, []))
|
||||
if code not in lst:
|
||||
lst.append(code) # ensure base node included
|
||||
lst.append(code) # ensure base node included
|
||||
return lst
|
||||
|
||||
# ── populate adjacency by brute-force queries to `abuts()` ────────────
|
||||
|
|
@ -55,48 +54,43 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
|||
graph[src]["ARMY"].append(dest)
|
||||
|
||||
# FLEET — any src variant that can host a fleet is valid
|
||||
if any(
|
||||
game_map.abuts("F", sv, "-", dv)
|
||||
for sv in src_vers
|
||||
for dv in dest_vers
|
||||
):
|
||||
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
|
||||
graph[src]["FLEET"].append(dest)
|
||||
|
||||
# ── tidy up duplicates / order ---------------------------------------
|
||||
for p in graph:
|
||||
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
|
||||
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
|
||||
graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"]))
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
|
||||
def bfs_shortest_path(
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
board_state: BoardState,
|
||||
game_map: GameMap, # Added game_map
|
||||
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
||||
unit_type: str,
|
||||
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
|
||||
) -> Optional[List[str]]: # Returns path of SHORT names
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
board_state: BoardState,
|
||||
game_map: GameMap, # Added game_map
|
||||
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
||||
unit_type: str,
|
||||
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
|
||||
) -> Optional[List[str]]: # Returns path of SHORT names
|
||||
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
||||
|
||||
|
||||
# Convert full start location to short province name
|
||||
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
||||
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||
start_loc_short = start_loc_short[:3]
|
||||
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
||||
# A simpler way for non-coastal full (like 'VIE') or already short:
|
||||
if '/' not in start_loc_full:
|
||||
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
||||
else: # Has '/', e.g. 'STP/SC'
|
||||
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
||||
if "/" not in start_loc_full:
|
||||
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
||||
else: # Has '/', e.g. 'STP/SC'
|
||||
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
||||
|
||||
if start_loc_short not in graph:
|
||||
logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.")
|
||||
return None
|
||||
|
||||
queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
|
||||
queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
|
||||
visited_nodes: Set[str] = {start_loc_short}
|
||||
|
||||
while queue:
|
||||
|
|
@ -104,14 +98,14 @@ def bfs_shortest_path(
|
|||
|
||||
# is_target_func expects a short location name
|
||||
if is_target_func(current_loc_short, board_state):
|
||||
return path # Path of short names
|
||||
return path # Path of short names
|
||||
|
||||
# possible_neighbors are SHORT names from the graph
|
||||
possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, [])
|
||||
|
||||
|
||||
for next_loc_short in possible_neighbors_short:
|
||||
if next_loc_short not in visited_nodes:
|
||||
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
|
||||
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
|
||||
logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.")
|
||||
continue
|
||||
visited_nodes.add(next_loc_short)
|
||||
|
|
@ -119,11 +113,12 @@ def bfs_shortest_path(
|
|||
queue.append((next_loc_short, new_path))
|
||||
return None
|
||||
|
||||
|
||||
# --- Helper functions for context generation ---
|
||||
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
||||
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
||||
for power, unit_list in board_state.get('units', {}).items():
|
||||
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
||||
for power, unit_list in board_state.get("units", {}).items():
|
||||
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
||||
parts = unit_str.split(" ")
|
||||
if len(parts) == 2:
|
||||
unit_map_loc = parts[1]
|
||||
|
|
@ -131,27 +126,29 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
|
|||
return f"{parts[0]} {location} ({power})"
|
||||
return None
|
||||
|
||||
|
||||
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
||||
"""Returns the controlling power's name if the location is an SC, else None."""
|
||||
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
||||
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
||||
if loc_province_name not in game_map.scs:
|
||||
return None
|
||||
for power, sc_list in board_state.get('centers', {}).items():
|
||||
for power, sc_list in board_state.get("centers", {}).items():
|
||||
if loc_province_name in sc_list:
|
||||
return power
|
||||
return None # Unowned SC
|
||||
return None # Unowned SC
|
||||
|
||||
|
||||
def get_shortest_path_to_friendly_unit(
|
||||
board_state: BoardState,
|
||||
board_state: BoardState,
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
game_map: GameMap, # Added game_map
|
||||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str
|
||||
game_map: GameMap, # Added game_map
|
||||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
"""Finds the shortest path to any friendly unit of the same power."""
|
||||
|
||||
|
||||
def is_target_friendly(loc_short: str, current_board_state: BoardState) -> bool:
|
||||
# loc_short is a short province name. Need to check all its full locations.
|
||||
full_locs_for_short = game_map.loc_coasts.get(loc_short, [loc_short])
|
||||
|
|
@ -162,7 +159,7 @@ def get_shortest_path_to_friendly_unit(
|
|||
return False
|
||||
|
||||
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly)
|
||||
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
|
||||
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
|
||||
target_loc_short = path_short_names[-1]
|
||||
# Find the actual friendly unit string at one of the full locations of target_loc_short
|
||||
friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT"
|
||||
|
|
@ -177,41 +174,41 @@ def get_shortest_path_to_friendly_unit(
|
|||
|
||||
|
||||
def get_nearest_enemy_units(
|
||||
board_state: BoardState,
|
||||
board_state: BoardState,
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
game_map: GameMap, # Added game_map
|
||||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str,
|
||||
n: int = 3
|
||||
game_map: GameMap, # Added game_map
|
||||
power_name: str,
|
||||
start_unit_loc_full: str,
|
||||
start_unit_type: str,
|
||||
n: int = 3,
|
||||
) -> List[Tuple[str, List[str]]]:
|
||||
"""Finds up to N nearest enemy units, sorted by path length."""
|
||||
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
||||
|
||||
all_enemy_unit_locations_full: List[Tuple[str,str]] = [] # (loc_full, unit_str_full)
|
||||
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
||||
|
||||
all_enemy_unit_locations_full: List[Tuple[str, str]] = [] # (loc_full, unit_str_full)
|
||||
# board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... }
|
||||
for p_name, unit_list_for_power in board_state.get("units", {}).items():
|
||||
if p_name != power_name: # If it's an enemy power
|
||||
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
|
||||
if p_name != power_name: # If it's an enemy power
|
||||
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
|
||||
parts = unit_repr_from_state.split(" ")
|
||||
if len(parts) == 2:
|
||||
# unit_type_char = parts[0] # 'A' or 'F'
|
||||
loc_full = parts[1] # 'PAR' or 'STP/SC'
|
||||
|
||||
loc_full = parts[1] # 'PAR' or 'STP/SC'
|
||||
|
||||
# Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)"
|
||||
full_unit_str_with_power = get_unit_at_location(board_state, loc_full)
|
||||
if full_unit_str_with_power: # Should find the unit if iteration is correct
|
||||
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
|
||||
if full_unit_str_with_power: # Should find the unit if iteration is correct
|
||||
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
|
||||
|
||||
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
||||
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
||||
if '/' in target_enemy_loc_short:
|
||||
if "/" in target_enemy_loc_short:
|
||||
target_enemy_loc_short = target_enemy_loc_short[:3]
|
||||
if '/' not in target_enemy_loc_full:
|
||||
if "/" not in target_enemy_loc_full:
|
||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||
else:
|
||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||
|
||||
|
||||
def is_specific_enemy_loc(loc_short: str, current_board_state: BoardState) -> bool:
|
||||
# Check if loc_short corresponds to target_enemy_loc_full
|
||||
return loc_short == target_enemy_loc_short
|
||||
|
|
@ -219,8 +216,8 @@ def get_nearest_enemy_units(
|
|||
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_specific_enemy_loc)
|
||||
if path_short_names:
|
||||
enemy_paths.append((enemy_unit_str, path_short_names))
|
||||
|
||||
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
|
||||
|
||||
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
|
||||
return enemy_paths[:n]
|
||||
|
||||
|
||||
|
|
@ -242,10 +239,10 @@ def get_nearest_uncontrolled_scs(
|
|||
"""
|
||||
results: List[Tuple[str, int, List[str]]] = []
|
||||
|
||||
for sc_short in game_map.scs: # all SC province codes
|
||||
for sc_short in game_map.scs: # all SC province codes
|
||||
controller = get_sc_controller(game_map, board_state, sc_short)
|
||||
if controller == power_name:
|
||||
continue # already ours
|
||||
continue # already ours
|
||||
|
||||
# helper for BFS target test
|
||||
def is_target(loc_short: str, _state: BoardState) -> bool:
|
||||
|
|
@ -260,9 +257,9 @@ def get_nearest_uncontrolled_scs(
|
|||
is_target,
|
||||
)
|
||||
if not path:
|
||||
continue # unreachable
|
||||
continue # unreachable
|
||||
|
||||
distance = len(path) - 1 # moves needed
|
||||
distance = len(path) - 1 # moves needed
|
||||
|
||||
# skip distance 0 (same province) and 1 (adjacent)
|
||||
if distance <= 1:
|
||||
|
|
@ -277,11 +274,11 @@ def get_nearest_uncontrolled_scs(
|
|||
|
||||
|
||||
def get_adjacent_territory_details(
|
||||
game_map: GameMap,
|
||||
board_state: BoardState,
|
||||
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
||||
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
||||
graph: Dict[str, Dict[str, List[str]]]
|
||||
game_map: GameMap,
|
||||
board_state: BoardState,
|
||||
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
||||
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
||||
graph: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""Generates a string describing adjacent territories and units that can interact with them."""
|
||||
output_lines: List[str] = []
|
||||
|
|
@ -291,37 +288,37 @@ def get_adjacent_territory_details(
|
|||
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
||||
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
||||
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
||||
if '/' in unit_loc_short:
|
||||
if "/" in unit_loc_short:
|
||||
unit_loc_short = unit_loc_short[:3]
|
||||
if '/' not in unit_loc_full:
|
||||
if "/" not in unit_loc_full:
|
||||
unit_loc_short = unit_loc_full[:3]
|
||||
else:
|
||||
unit_loc_short = unit_loc_full[:3]
|
||||
|
||||
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
|
||||
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
|
||||
|
||||
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
|
||||
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
|
||||
|
||||
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
|
||||
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
|
||||
# adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed
|
||||
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
|
||||
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
|
||||
continue
|
||||
processed_adj_provinces.add(adj_loc_short)
|
||||
|
||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
|
||||
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
|
||||
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
|
||||
elif adj_loc_type == 'WATER':
|
||||
adj_loc_type_display = 'WATER'
|
||||
else: # SHUT etc.
|
||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
|
||||
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
|
||||
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
|
||||
elif adj_loc_type == "WATER":
|
||||
adj_loc_type_display = "WATER"
|
||||
else: # SHUT etc.
|
||||
adj_loc_type_display = adj_loc_type
|
||||
|
||||
|
||||
line = f" {adj_loc_short} ({adj_loc_type_display})"
|
||||
|
||||
|
||||
sc_controller = get_sc_controller(game_map, board_state, adj_loc_short)
|
||||
if sc_controller:
|
||||
line += f" SC Control: {sc_controller}"
|
||||
|
||||
|
||||
unit_in_adj_loc = get_unit_at_location(board_state, adj_loc_short)
|
||||
if unit_in_adj_loc:
|
||||
line += f" Units: {unit_in_adj_loc}"
|
||||
|
|
@ -330,15 +327,14 @@ def get_adjacent_territory_details(
|
|||
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
||||
# A true "can support/move to" would require checking possible orders of those further units.
|
||||
# further_adj_provinces are short names from the graph
|
||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
|
||||
graph.get(adj_loc_short, {}).get('FLEET', [])
|
||||
|
||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
|
||||
|
||||
supporting_units_info = []
|
||||
processed_further_provinces = set()
|
||||
for further_adj_loc_short in further_adj_provinces_short:
|
||||
# further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3]
|
||||
# No conversion needed, it's already short
|
||||
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
|
||||
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
|
||||
continue
|
||||
if further_adj_loc_short in processed_further_provinces:
|
||||
continue
|
||||
|
|
@ -354,14 +350,14 @@ def get_adjacent_territory_details(
|
|||
temp_unit = get_unit_at_location(board_state, fv_further)
|
||||
if temp_unit:
|
||||
unit_in_further_loc = temp_unit
|
||||
break # Found a unit in one of the coasts/base
|
||||
|
||||
break # Found a unit in one of the coasts/base
|
||||
|
||||
# if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short:
|
||||
# unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short)
|
||||
|
||||
|
||||
if unit_in_further_loc:
|
||||
supporting_units_info.append(unit_in_further_loc)
|
||||
|
||||
|
||||
if supporting_units_info:
|
||||
output_lines.append(f" => Can support/move to: {', '.join(sorted(list(set(supporting_units_info))))}")
|
||||
|
||||
|
|
@ -378,30 +374,30 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
board_state: BoardState = game.get_state()
|
||||
game_map: GameMap = game.map
|
||||
graph = build_diplomacy_graph(game_map)
|
||||
|
||||
|
||||
final_context_lines: List[str] = ["<PossibleOrdersContext>"]
|
||||
|
||||
# Iterate through units that have orders (keys of possible_orders_for_power are unit locations)
|
||||
for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items():
|
||||
unit_str_full = get_unit_at_location(board_state, unit_loc_full)
|
||||
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
|
||||
continue
|
||||
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
|
||||
continue
|
||||
|
||||
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
||||
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
|
||||
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
||||
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
|
||||
|
||||
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
||||
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
||||
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
|
||||
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
|
||||
if loc_type_short == "COAST" or loc_type_short == "LAND":
|
||||
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
|
||||
else:
|
||||
loc_type_display = loc_type_short
|
||||
|
||||
current_unit_lines: List[str] = []
|
||||
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
||||
|
||||
|
||||
# Unit Information section
|
||||
current_unit_lines.append(' <UnitInformation>')
|
||||
current_unit_lines.append(" <UnitInformation>")
|
||||
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
||||
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
||||
if sc_owner_at_loc == power_name:
|
||||
|
|
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
||||
current_unit_lines.append(f" {header_content}")
|
||||
current_unit_lines.append(f" Units present: {unit_str_full}")
|
||||
current_unit_lines.append(' </UnitInformation>')
|
||||
current_unit_lines.append(" </UnitInformation>")
|
||||
|
||||
# Possible moves section
|
||||
current_unit_lines.append(' <PossibleMoves>')
|
||||
current_unit_lines.append(" <PossibleMoves>")
|
||||
current_unit_lines.append(" Possible moves:")
|
||||
for order_str in unit_specific_possible_orders:
|
||||
current_unit_lines.append(f" {order_str}")
|
||||
current_unit_lines.append(' </PossibleMoves>')
|
||||
|
||||
current_unit_lines.append(" </PossibleMoves>")
|
||||
|
||||
# Nearest enemy units section
|
||||
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
||||
current_unit_lines.append(' <NearestEnemyUnits>')
|
||||
current_unit_lines.append(" <NearestEnemyUnits>")
|
||||
if enemy_units_info:
|
||||
current_unit_lines.append(" Nearest units (not ours):")
|
||||
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
||||
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
|
||||
current_unit_lines.append(
|
||||
f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
|
||||
)
|
||||
else:
|
||||
current_unit_lines.append(" Nearest units (not ours): None found")
|
||||
current_unit_lines.append(' </NearestEnemyUnits>')
|
||||
current_unit_lines.append(" </NearestEnemyUnits>")
|
||||
|
||||
# Nearest supply centers (not controlled by us) section
|
||||
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
||||
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
|
||||
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
|
||||
if uncontrolled_scs_info:
|
||||
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
||||
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
||||
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
|
||||
current_unit_lines.append(
|
||||
f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
|
||||
)
|
||||
else:
|
||||
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
||||
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
|
||||
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
|
||||
|
||||
# Adjacent territories details section
|
||||
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
||||
current_unit_lines.append(' <AdjacentTerritories>')
|
||||
current_unit_lines.append(" <AdjacentTerritories>")
|
||||
if adj_details_str:
|
||||
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
||||
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
||||
# We might need to indent adj_details_str if it's a single block of text
|
||||
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
||||
if '\n' in adj_details_str:
|
||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
|
||||
if "\n" in adj_details_str:
|
||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
|
||||
current_unit_lines.append(indented_adj_details)
|
||||
else:
|
||||
current_unit_lines.append(f" {adj_details_str}")
|
||||
current_unit_lines.append(f" {adj_details_str}")
|
||||
else:
|
||||
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
|
||||
current_unit_lines.append(' </AdjacentTerritories>')
|
||||
|
||||
current_unit_lines.append(' </UnitContext>')
|
||||
current_unit_lines.append(
|
||||
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
|
||||
) # Added more descriptive else
|
||||
current_unit_lines.append(" </AdjacentTerritories>")
|
||||
|
||||
current_unit_lines.append(" </UnitContext>")
|
||||
final_context_lines.extend(current_unit_lines)
|
||||
|
||||
final_context_lines.append("</PossibleOrdersContext>")
|
||||
|
|
@ -469,16 +471,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
|||
# Regex and tiny helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
|
||||
# ── order-syntax matchers ─────────────────────────────────────────────────
|
||||
_SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
||||
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
||||
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
||||
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
||||
|
||||
def _is_hold_order(order: str) -> bool: # NEW
|
||||
|
||||
def _is_hold_order(order: str) -> bool: # NEW
|
||||
return bool(_HOLD_RE.match(order.strip()))
|
||||
|
||||
|
||||
|
|
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
|
|||
"""Trim & uppercase for reliable comparisons."""
|
||||
return name.strip().upper()
|
||||
|
||||
|
||||
def _is_simple_move(order: str) -> bool:
|
||||
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _is_retreat_order(order: str) -> bool:
|
||||
return bool(_RETREAT_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _is_adjust_order(order: str) -> bool:
|
||||
return bool(_ADJUST_RE.match(order.strip()))
|
||||
|
||||
|
||||
def _split_move(order: str) -> Tuple[str, str]:
|
||||
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
||||
unit_part, dest = order.split(" - ")
|
||||
|
|
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
|
|||
# Gather *all* friendly support orders for a given move
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _all_support_examples(
|
||||
mover: str,
|
||||
dest: str,
|
||||
|
|
@ -526,6 +533,7 @@ def _all_support_examples(
|
|||
|
||||
return supports
|
||||
|
||||
|
||||
def _all_support_hold_examples(
|
||||
holder: str,
|
||||
all_orders: Dict[str, List[str]],
|
||||
|
|
@ -538,7 +546,7 @@ def _all_support_hold_examples(
|
|||
supports: List[str] = []
|
||||
|
||||
for loc, orders in all_orders.items():
|
||||
if holder.endswith(loc): # skip the holding unit
|
||||
if holder.endswith(loc): # skip the holding unit
|
||||
continue
|
||||
for o in orders:
|
||||
if o.strip().endswith(target):
|
||||
|
|
@ -550,6 +558,7 @@ def _all_support_hold_examples(
|
|||
# Province-type resolver (handles short codes, coasts, seas)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _province_type_display(game_map, prov_short: str) -> str:
|
||||
"""
|
||||
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
||||
|
|
@ -575,7 +584,7 @@ def _dest_occupancy_desc(
|
|||
board_state,
|
||||
our_power: str,
|
||||
) -> str:
|
||||
""" '(occupied by X)', '(occupied by X — you!)', or '(unoccupied)' """
|
||||
"""'(occupied by X)', '(occupied by X — you!)', or '(unoccupied)'"""
|
||||
occupant: Optional[str] = None
|
||||
for full in game_map.loc_coasts.get(dest_short, [dest_short]):
|
||||
u = get_unit_at_location(board_state, full)
|
||||
|
|
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
|
|||
# Adjacent-territory lines (used by movement-phase builder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _adjacent_territory_lines(
|
||||
graph,
|
||||
game_map,
|
||||
|
|
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
|
|||
pwr = unit_here.split(" ")[-1].strip("()")
|
||||
if pwr == our_power:
|
||||
friend_descr = unit_here.split(" (")[0]
|
||||
lines.append(
|
||||
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
|
||||
)
|
||||
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
|
||||
else:
|
||||
lines.append(
|
||||
f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa"
|
||||
)
|
||||
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa")
|
||||
|
||||
return lines
|
||||
|
||||
|
|
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
|
|||
# Movement-phase generator (UNCHANGED LOGIC)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_movement(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
|
|||
if _norm_power(unit_power) != me:
|
||||
continue # Skip units that aren’t ours
|
||||
|
||||
mover_descr, _ = _split_move(
|
||||
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
|
||||
)
|
||||
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
|
||||
|
||||
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
||||
prov_type_disp = _province_type_display(game_map, prov_short)
|
||||
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
||||
|
||||
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
||||
owner_line = (
|
||||
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||
)
|
||||
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||
|
||||
ind = " "
|
||||
block: List[str] = [f"<Territory {prov_short}>"]
|
||||
|
|
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
|
|||
|
||||
# ----- Adjacent territories -----
|
||||
block.append("# Adjacent territories:")
|
||||
block.extend(
|
||||
_adjacent_territory_lines(
|
||||
graph, game_map, board_state,
|
||||
unit_loc_full, mover_descr, power_name
|
||||
)
|
||||
)
|
||||
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
|
||||
|
||||
# ----- Nearest enemy units -----
|
||||
block.append("# Nearest units (not ours):")
|
||||
enemies = get_nearest_enemy_units(
|
||||
board_state, graph, game_map,
|
||||
power_name, unit_loc_full,
|
||||
board_state,
|
||||
graph,
|
||||
game_map,
|
||||
power_name,
|
||||
unit_loc_full,
|
||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||
n=3,
|
||||
)
|
||||
|
|
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
|
|||
# ----- Nearest uncontrolled SCs -----
|
||||
block.append("# Nearest supply centers (not controlled by us):")
|
||||
scs = get_nearest_uncontrolled_scs(
|
||||
game_map, board_state, graph,
|
||||
power_name, unit_loc_full,
|
||||
game_map,
|
||||
board_state,
|
||||
graph,
|
||||
power_name,
|
||||
unit_loc_full,
|
||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||
n=3,
|
||||
)
|
||||
|
|
@ -729,7 +733,7 @@ def _generate_rich_order_context_movement(
|
|||
block.append(f"# Possible {mover_descr} unit movements & supports:")
|
||||
|
||||
simple_moves = [o for o in orders if _is_simple_move(o)]
|
||||
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
|
||||
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
|
||||
|
||||
if not simple_moves and not hold_orders:
|
||||
block.append(f"{ind}None")
|
||||
|
|
@ -737,22 +741,19 @@ def _generate_rich_order_context_movement(
|
|||
# ---- Moves (same behaviour as before) ----
|
||||
for mv in simple_moves:
|
||||
mover, dest = _split_move(mv)
|
||||
occ = _dest_occupancy_desc(
|
||||
dest.split("/")[0][:3], game_map, board_state, power_name
|
||||
)
|
||||
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
|
||||
block.append(f"{ind}{mv} {occ}")
|
||||
|
||||
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
||||
block.append(f"{ind*2}Available Support: {s}")
|
||||
block.append(f"{ind * 2}Available Support: {s}")
|
||||
|
||||
# ---- Holds (new) ----
|
||||
for hd in hold_orders:
|
||||
holder = hd.split(" H")[0] # e.g., 'F DEN'
|
||||
holder = hd.split(" H")[0] # e.g., 'F DEN'
|
||||
block.append(f"{ind}{hd}")
|
||||
|
||||
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
||||
block.append(f"{ind*2}Available Support: {s}")
|
||||
|
||||
block.append(f"{ind * 2}Available Support: {s}")
|
||||
|
||||
block.append(f"</Territory {prov_short}>")
|
||||
blocks.append("\n".join(block))
|
||||
|
|
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
|
|||
# Retreat-phase builder – echo orders verbatim, no tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_retreat(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
|
|||
return "\n".join(lines) if lines else "(No dislodged units)"
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_rich_order_context_adjustment(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -800,12 +802,12 @@ def _generate_rich_order_context_adjustment(
|
|||
* No wrapper tags.
|
||||
"""
|
||||
board_state = game.get_state()
|
||||
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
|
||||
units_num = len(board_state.get("units", {}).get(power_name, []))
|
||||
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
|
||||
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
|
||||
units_num = len(board_state.get("units", {}).get(power_name, []))
|
||||
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
|
||||
|
||||
# ----- summary line ----------------------------------------------------
|
||||
if delta > 0:
|
||||
if delta > 0:
|
||||
summary = f"Builds available: {delta}"
|
||||
elif delta < 0:
|
||||
summary = f"Disbands required: {-delta}"
|
||||
|
|
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
|
|||
return "\n".join(lines) if len(lines) > 1 else summary
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase-dispatch wrapper (public entry point)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_rich_order_context(
|
||||
game: Any,
|
||||
power_name: str,
|
||||
|
|
@ -846,22 +846,14 @@ def generate_rich_order_context(
|
|||
|
||||
phase_type = game.current_short_phase[-1]
|
||||
|
||||
if phase_type == "M": # Movement
|
||||
return _generate_rich_order_context_movement(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
if phase_type == "M": # Movement
|
||||
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||
|
||||
if phase_type == "R": # Retreat
|
||||
return _generate_rich_order_context_retreat(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
if phase_type == "R": # Retreat
|
||||
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
|
||||
|
||||
if phase_type == "A": # Adjustment (build / disband)
|
||||
return _generate_rich_order_context_adjustment(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
if phase_type == "A": # Adjustment (build / disband)
|
||||
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
|
||||
|
||||
# Fallback – treat unknown formats as movement
|
||||
return _generate_rich_order_context_movement(
|
||||
game, power_name, possible_orders_for_power
|
||||
)
|
||||
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from .possible_order_context import (
|
|||
generate_rich_order_context,
|
||||
generate_rich_order_context_xml,
|
||||
)
|
||||
import os
|
||||
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
||||
|
||||
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
||||
|
|
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
|
|||
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
||||
|
||||
return final_prompt
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from ..config import config
|
|||
if TYPE_CHECKING:
|
||||
from .clients import BaseModelClient
|
||||
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
|
||||
# from .agent import DiplomacyAgent
|
||||
# from .agent import DiplomacyAgent
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -32,12 +32,12 @@ def atomic_write_json(data: dict, filepath: str):
|
|||
dir_name = os.path.dirname(filepath)
|
||||
if dir_name:
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
|
||||
# Write to a temporary file in the same directory
|
||||
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
||||
with open(temp_filepath, 'w', encoding='utf-8') as f:
|
||||
with open(temp_filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
# Atomically rename the temporary file to the final destination
|
||||
os.rename(temp_filepath, filepath)
|
||||
except Exception as e:
|
||||
|
|
@ -55,15 +55,15 @@ def assign_models_to_powers() -> Dict[str, str]:
|
|||
Example usage: define which model each power uses.
|
||||
Return a dict: { power_name: model_id, ... }
|
||||
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
|
||||
Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
|
||||
claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
|
||||
gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
|
||||
Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
|
||||
claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
|
||||
gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
|
||||
deepseek-chat, deepseek-reasoner
|
||||
openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
|
||||
openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
|
||||
openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
|
||||
openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
|
||||
"""
|
||||
|
||||
|
||||
# POWER MODELS
|
||||
"""
|
||||
return {
|
||||
|
|
@ -77,13 +77,13 @@ def assign_models_to_powers() -> Dict[str, str]:
|
|||
}
|
||||
"""
|
||||
# TEST MODELS
|
||||
|
||||
|
||||
return {
|
||||
"AUSTRIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"ENGLAND": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"FRANCE": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"GERMANY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"RUSSIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
"TURKEY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
|
||||
}
|
||||
|
|
@ -92,30 +92,26 @@ def assign_models_to_powers() -> Dict[str, str]:
|
|||
def get_special_models() -> Dict[str, str]:
|
||||
"""
|
||||
Define models for special purposes like phase summaries and formatting.
|
||||
|
||||
|
||||
These can be overridden via environment variables:
|
||||
- AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
|
||||
- AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting (default: "google/gemini-2.5-flash-lite-preview-06-17")
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"phase_summary": model for generating narrative phase summaries,
|
||||
"formatter": model for formatting natural language to JSON
|
||||
}
|
||||
|
||||
|
||||
Examples:
|
||||
# Use Claude for phase summaries
|
||||
export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"
|
||||
|
||||
|
||||
# Use a different Gemini model for formatting
|
||||
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
||||
"""
|
||||
return {
|
||||
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
|
||||
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
|
||||
}
|
||||
|
||||
|
||||
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
|
||||
|
||||
|
||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||
"""
|
||||
|
|
@ -132,7 +128,7 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
|||
|
||||
async def get_valid_orders(
|
||||
game: Game,
|
||||
client, # BaseModelClient instance
|
||||
client, # BaseModelClient instance
|
||||
board_state,
|
||||
power_name: str,
|
||||
possible_orders: Dict[str, List[str]],
|
||||
|
|
@ -170,20 +166,24 @@ async def get_valid_orders(
|
|||
)
|
||||
|
||||
invalid_info: list[str] = []
|
||||
valid: list[str] = []
|
||||
valid: list[str] = []
|
||||
invalid: list[str] = []
|
||||
|
||||
# ── 2. Type check ──────────────────────────────────────────
|
||||
if not isinstance(raw_orders, list):
|
||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
|
||||
power_name, raw_orders)
|
||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
||||
|
||||
# ── 3. Round-trip validation with engine ───────────────────
|
||||
CODE_TO_ENGINE = {
|
||||
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
|
||||
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
|
||||
"AUT": "AUSTRIA",
|
||||
"ENG": "ENGLAND",
|
||||
"FRA": "FRANCE",
|
||||
"GER": "GERMANY",
|
||||
"ITA": "ITALY",
|
||||
"RUS": "RUSSIA",
|
||||
"TUR": "TURKEY",
|
||||
}
|
||||
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
||||
|
||||
|
|
@ -202,16 +202,16 @@ async def get_valid_orders(
|
|||
game.set_orders(engine_power, [upper])
|
||||
normed = game.get_orders(engine_power)
|
||||
|
||||
if normed: # accepted
|
||||
if normed: # accepted
|
||||
valid.append(normed[0])
|
||||
else: # rejected
|
||||
else: # rejected
|
||||
invalid.append(upper)
|
||||
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
|
||||
|
||||
game.clear_orders(engine_power) # clean slate for main engine flow
|
||||
|
||||
# ── 4. Legacy logging & stats updates ──────────────────────
|
||||
if invalid_info: # at least one bad move
|
||||
if invalid_info: # at least one bad move
|
||||
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
|
||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
|
||||
|
|
@ -226,8 +226,6 @@ async def get_valid_orders(
|
|||
return {"valid": valid, "invalid": invalid}
|
||||
|
||||
|
||||
|
||||
|
||||
def normalize_and_compare_orders(
|
||||
issued_orders: Dict[str, List[str]],
|
||||
accepted_orders_dict: Dict[str, List[str]],
|
||||
|
|
@ -325,18 +323,18 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
|||
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
|
||||
4. Otherwise → join the package’s default prompts dir.
|
||||
"""
|
||||
if os.path.isabs(filename): # rule 1
|
||||
if os.path.isabs(filename): # rule 1
|
||||
prompt_path = filename
|
||||
elif os.path.dirname(filename): # rule 2 (has slash)
|
||||
elif os.path.dirname(filename): # rule 2 (has slash)
|
||||
# If it's a relative path with directory, join with prompts_dir if provided
|
||||
if prompts_dir:
|
||||
prompt_path = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_path = os.path.join(default_dir, filename)
|
||||
elif prompts_dir: # rule 3
|
||||
elif prompts_dir: # rule 3
|
||||
prompt_path = os.path.join(prompts_dir, filename)
|
||||
else: # rule 4
|
||||
else: # rule 4
|
||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_path = os.path.join(default_dir, filename)
|
||||
|
||||
|
|
@ -348,16 +346,14 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
|
||||
|
||||
# == New LLM Response Logging Function ==
|
||||
def log_llm_response(
|
||||
log_file_path: str,
|
||||
model_name: str,
|
||||
power_name: Optional[str], # Optional for non-power-specific calls like summary
|
||||
power_name: Optional[str], # Optional for non-power-specific calls like summary
|
||||
phase: str,
|
||||
response_type: str,
|
||||
raw_input_prompt: str, # Added new parameter for the raw input
|
||||
raw_input_prompt: str, # Added new parameter for the raw input
|
||||
raw_response: str,
|
||||
success: str, # Changed from bool to str
|
||||
):
|
||||
|
|
@ -365,8 +361,8 @@ def log_llm_response(
|
|||
try:
|
||||
# Ensure the directory exists
|
||||
log_dir = os.path.dirname(log_file_path)
|
||||
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Check if file exists and has content to determine if we need headers
|
||||
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
|
||||
|
|
@ -374,34 +370,38 @@ def log_llm_response(
|
|||
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
||||
# Added "raw_input" to fieldnames
|
||||
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
|
||||
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
||||
escapechar='\\') # Use backslash for escaping
|
||||
writer = csv.DictWriter(
|
||||
csvfile,
|
||||
fieldnames=fieldnames,
|
||||
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
||||
escapechar="\\",
|
||||
) # Use backslash for escaping
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader() # Write header only if file is new
|
||||
|
||||
writer.writerow({
|
||||
"model": model_name,
|
||||
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
||||
"phase": phase,
|
||||
"response_type": response_type,
|
||||
"raw_input": raw_input_prompt, # Added raw_input to the row
|
||||
"raw_response": raw_response,
|
||||
"success": success,
|
||||
})
|
||||
writer.writerow(
|
||||
{
|
||||
"model": model_name,
|
||||
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
||||
"phase": phase,
|
||||
"response_type": response_type,
|
||||
"raw_input": raw_input_prompt, # Added raw_input to the row
|
||||
"raw_response": raw_response,
|
||||
"success": success,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
||||
|
||||
|
||||
|
||||
async def run_llm_and_log(
|
||||
client: "BaseModelClient",
|
||||
prompt: str,
|
||||
log_file_path: str, # Kept for context, but not used for logging here
|
||||
log_file_path: str, # Kept for context, but not used for logging here
|
||||
power_name: Optional[str], # Kept for context, but not used for logging here
|
||||
phase: str, # Kept for context, but not used for logging here
|
||||
response_type: str, # Kept for context, but not used for logging here
|
||||
phase: str, # Kept for context, but not used for logging here
|
||||
response_type: str, # Kept for context, but not used for logging here
|
||||
temperature: float = 0.0,
|
||||
*,
|
||||
attempts: int = 5,
|
||||
|
|
@ -428,43 +428,34 @@ async def run_llm_and_log(
|
|||
except Exception as e:
|
||||
if attempt == attempts - 1:
|
||||
logger.error(
|
||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
|
||||
f"in phase {phase}: {e}",
|
||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Back-off before the next attempt (unless this was the last)
|
||||
delay = backoff_base * (backoff_factor ** attempt) + random.uniform(0, jitter)
|
||||
delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return raw_response
|
||||
|
||||
|
||||
|
||||
# This generates a few lines of random alphanum chars to inject into the
|
||||
# system prompt. This lets us use temp=0 while still getting variation
|
||||
# This generates a few lines of random alphanum chars to inject into the
|
||||
# system prompt. This lets us use temp=0 while still getting variation
|
||||
# between trials.
|
||||
# Temp=0 is important for better performance on deciding moves, and to
|
||||
# Temp=0 is important for better performance on deciding moves, and to
|
||||
# ensure valid json outputs.
|
||||
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
||||
# Generate x lines of y random alphanumeric characters
|
||||
seed_lines = [
|
||||
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
|
||||
for _ in range(n_lines)
|
||||
]
|
||||
random_seed_block = (
|
||||
"<RANDOM SEED PLEASE IGNORE>\n" +
|
||||
"\n".join(seed_lines) +
|
||||
"\n</RANDOM SEED>"
|
||||
)
|
||||
return random_seed_block
|
||||
# Generate x lines of y random alphanumeric characters
|
||||
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
|
||||
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
|
||||
return random_seed_block
|
||||
|
||||
|
||||
def get_prompt_path(prompt_name: str) -> str:
|
||||
"""Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting.
|
||||
|
||||
|
||||
Args:
|
||||
prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt")
|
||||
|
||||
|
||||
Returns:
|
||||
str: Either "unformatted/{prompt_name}" or just "{prompt_name}"
|
||||
"""
|
||||
|
|
@ -472,36 +463,37 @@ def get_prompt_path(prompt_name: str) -> str:
|
|||
return f"unformatted/{prompt_name}"
|
||||
else:
|
||||
return prompt_name
|
||||
|
||||
def normalize_recipient_name(recipient: str) -> str:
|
||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||
if not recipient:
|
||||
return recipient
|
||||
|
||||
recipient = recipient.upper().strip()
|
||||
|
||||
# Handle common LLM typos and abbreviations found in data
|
||||
name_mapping = {
|
||||
'EGMANY': 'GERMANY',
|
||||
'GERMAN': 'GERMANY',
|
||||
'UK': 'ENGLAND',
|
||||
'BRIT': 'ENGLAND',
|
||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
||||
'FRANCE': 'FRANCE', # Keep as-is
|
||||
'GERMANY': 'GERMANY', # Keep as-is
|
||||
'ITALY': 'ITALY', # Keep as-is
|
||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
||||
'TURKEY': 'TURKEY', # Keep as-is
|
||||
'Germany': 'GERMANY',
|
||||
'England': 'ENGLAND',
|
||||
'France': 'FRANCE',
|
||||
'Italy': 'ITALY',
|
||||
'Russia': 'RUSSIA',
|
||||
'Austria': 'AUSTRIA',
|
||||
'Turkey': 'TURKEY',
|
||||
}
|
||||
|
||||
normalized = name_mapping.get(recipient, recipient)
|
||||
|
||||
return normalized
|
||||
|
||||
def normalize_recipient_name(recipient: str) -> str:
|
||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||
if not recipient:
|
||||
return recipient
|
||||
|
||||
recipient = recipient.upper().strip()
|
||||
|
||||
# Handle common LLM typos and abbreviations found in data
|
||||
name_mapping = {
|
||||
"EGMANY": "GERMANY",
|
||||
"GERMAN": "GERMANY",
|
||||
"UK": "ENGLAND",
|
||||
"BRIT": "ENGLAND",
|
||||
"ENGLAND": "ENGLAND", # Keep as-is
|
||||
"FRANCE": "FRANCE", # Keep as-is
|
||||
"GERMANY": "GERMANY", # Keep as-is
|
||||
"ITALY": "ITALY", # Keep as-is
|
||||
"AUSTRIA": "AUSTRIA", # Keep as-is
|
||||
"RUSSIA": "RUSSIA", # Keep as-is
|
||||
"TURKEY": "TURKEY", # Keep as-is
|
||||
"Germany": "GERMANY",
|
||||
"England": "ENGLAND",
|
||||
"France": "FRANCE",
|
||||
"Italy": "ITALY",
|
||||
"Russia": "RUSSIA",
|
||||
"Austria": "AUSTRIA",
|
||||
"Turkey": "TURKEY",
|
||||
}
|
||||
|
||||
normalized = name_mapping.get(recipient, recipient)
|
||||
|
||||
return normalized
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue